diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index f9fbba52..d0eb0657 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -70,7 +70,7 @@ def setup_generator(c): MyModel = importlib.import_module('TTS.vocoder.models.' + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model in 'melgan_generator': + if c.generator_model.lower() in 'melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, @@ -81,7 +81,7 @@ def setup_generator(c): num_res_blocks=c.generator_model_params['num_res_blocks']) if c.generator_model in 'melgan_fb_generator': pass - if c.generator_model in 'multiband_melgan_generator': + if c.generator_model.lower() in 'multiband_melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=4, @@ -90,7 +90,7 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model in 'fullband_melgan_generator': + if c.generator_model.lower() in 'fullband_melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, @@ -99,7 +99,7 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model in 'parallel_wavegan_generator': + if c.generator_model.lower() in 'parallel_wavegan_generator': model = MyModel( in_channels=1, out_channels=1, @@ -114,6 +114,16 @@ def setup_generator(c): bias=True, use_weight_norm=True, upsample_factors=c.generator_model_params['upsample_factors']) + if c.generator_model.lower() in 'wavegrad': + model = MyModel( + in_channels=c['audio']['num_mels'], + out_channels=1, + x_conv_channels=c['model_params']['x_conv_channels'], + c_conv_channels=c['model_params']['c_conv_channels'], + dblock_out_channels=c['model_params']['dblock_out_channels'], + ublock_out_channels=c['model_params']['ublock_out_channels'], + upsample_factors=c['model_params']['upsample_factors'], + upsample_dilations=c['model_params']['upsample_dilations']) return model