From 69f525f17de591e5adfa8dd31b302a3e2a409dd7 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 17 Jul 2020 11:37:33 +0200 Subject: [PATCH] load pwgan models --- TTS/vocoder/utils/generic_utils.py | 53 ++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index bef2b35b..9626c1ce 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -67,14 +67,34 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) + if c.generator_model in 'parallel_wavegan_generator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=c.generator_model_params['num_res_blocks'], + stacks=c.generator_model_params['stacks'], + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=c.audio['num_mels'], + aux_context_window=c['conv_pad'], + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_conditional_features=True, + upsample_factors=c.generator_model_params['upsample_factors']) return model def setup_discriminator(c): print(" > Discriminator Model: {}".format(c.discriminator_model)) - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.discriminator_model.lower()) - MyModel = getattr(MyModel, to_camel(c.discriminator_model)) + if 'parallel_wavegan' in c.discriminator_model: + MyModel = importlib.import_module('TTS.vocoder.models.parallel_wavegan_discriminator') + else: + MyModel = importlib.import_module('TTS.vocoder.models.' + + c.discriminator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) if c.discriminator_model in 'random_window_discriminator': model = MyModel( cond_channels=c.audio['num_mels'], @@ -95,6 +115,33 @@ def setup_discriminator(c): max_channels=c.discriminator_model_params['max_channels'], downsample_factors=c. discriminator_model_params['downsample_factors']) + if c.discriminator_model == 'residual_parallel_wavegan_discriminator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params['num_layers'], + stacks=c.discriminator_model_params['stacks'], + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ) + if c.discriminator_model == 'parallel_wavegan_discriminator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params['num_layers'], + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True + ) return model