add model params to config

This commit is contained in:
sanjaesc 2020-10-25 09:47:04 +01:00
parent 4a989e3ceb
commit 80f5e39e56
1 changed files with 9 additions and 9 deletions

View File

@ -47,18 +47,18 @@ def setup_wavernn(c):
MyModel = importlib.import_module("TTS.vocoder.models.wavernn") MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
MyModel = getattr(MyModel, "WaveRNN") MyModel = getattr(MyModel, "WaveRNN")
model = MyModel( model = MyModel(
rnn_dims=512, rnn_dims=c.wavernn_model_params['rnn_dims'],
fc_dims=512, fc_dims=c.wavernn_model_params['fc_dims'],
mode=c.mode, mode=c.mode,
mulaw=c.mulaw, mulaw=c.mulaw,
pad=c.padding, pad=c.padding,
use_aux_net=c.use_aux_net, use_aux_net=c.wavernn_model_params['use_aux_net'],
use_upsample_net=c.use_upsample_net, use_upsample_net=c.wavernn_model_params['use_upsample_net'],
upsample_factors=c.upsample_factors, upsample_factors=c.wavernn_model_params['upsample_factors'],
feat_dims=80, feat_dims=c.audio['num_mels'],
compute_dims=128, compute_dims=c.wavernn_model_params['compute_dims'],
res_out_dims=128, res_out_dims=c.wavernn_model_params['res_out_dims'],
res_blocks=10, num_res_blocks=c.wavernn_model_params['num_res_blocks'],
hop_length=c.audio["hop_length"], hop_length=c.audio["hop_length"],
sample_rate=c.audio["sample_rate"], sample_rate=c.audio["sample_rate"],
) )