diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index c16fa1ae..f9fbba52 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -47,18 +47,18 @@ def setup_wavernn(c): MyModel = importlib.import_module("TTS.vocoder.models.wavernn") MyModel = getattr(MyModel, "WaveRNN") model = MyModel( - rnn_dims=512, - fc_dims=512, + rnn_dims=c.wavernn_model_params['rnn_dims'], + fc_dims=c.wavernn_model_params['fc_dims'], mode=c.mode, mulaw=c.mulaw, pad=c.padding, - use_aux_net=c.use_aux_net, - use_upsample_net=c.use_upsample_net, - upsample_factors=c.upsample_factors, - feat_dims=80, - compute_dims=128, - res_out_dims=128, - res_blocks=10, + use_aux_net=c.wavernn_model_params['use_aux_net'], + use_upsample_net=c.wavernn_model_params['use_upsample_net'], + upsample_factors=c.wavernn_model_params['upsample_factors'], + feat_dims=c.audio['num_mels'], + compute_dims=c.wavernn_model_params['compute_dims'], + res_out_dims=c.wavernn_model_params['res_out_dims'], + num_res_blocks=c.wavernn_model_params['num_res_blocks'], hop_length=c.audio["hop_length"], sample_rate=c.audio["sample_rate"], )