improve robustness of defining wavernn in config file

This commit is contained in:
Branislav Gerazov 2021-02-05 13:26:33 +01:00
parent 24ffa9e9f6
commit f063545325
1 changed files with 3 additions and 3 deletions

View File

@ -71,10 +71,10 @@ def setup_generator(c):
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.generator_model.lower())
# this is to preserve the WaveRNN class name (instead of Wavernn)
if c.generator_model != 'WaveRNN':
MyModel = getattr(MyModel, to_camel(c.generator_model))
if c.generator_model.lower() == 'wavernn':
MyModel = getattr(MyModel, 'WaveRNN')
else:
MyModel = getattr(MyModel, c.generator_model)
MyModel = getattr(MyModel, to_camel(c.generator_model))
if c.generator_model.lower() in 'wavernn':
model = MyModel(
rnn_dims=c.wavernn_model_params['rnn_dims'],