mirror of https://github.com/coqui-ai/TTS.git
improve robustness of defining wavernn in config file
This commit is contained in:
parent
24ffa9e9f6
commit
f063545325
|
@ -71,10 +71,10 @@ def setup_generator(c):
|
|||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||
c.generator_model.lower())
|
||||
# this is to preserve the WaveRNN class name (instead of Wavernn)
|
||||
if c.generator_model != 'WaveRNN':
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
if c.generator_model.lower() == 'wavernn':
|
||||
MyModel = getattr(MyModel, 'WaveRNN')
|
||||
else:
|
||||
MyModel = getattr(MyModel, c.generator_model)
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
if c.generator_model.lower() in 'wavernn':
|
||||
model = MyModel(
|
||||
rnn_dims=c.wavernn_model_params['rnn_dims'],
|
||||
|
|
Loading…
Reference in New Issue