mirror of https://github.com/coqui-ai/TTS.git
improve robustness of defining wavernn in config file
This commit is contained in:
parent
24ffa9e9f6
commit
f063545325
|
@ -71,10 +71,10 @@ def setup_generator(c):
|
||||||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||||
c.generator_model.lower())
|
c.generator_model.lower())
|
||||||
# this is to preserve the WaveRNN class name (instead of Wavernn)
|
# this is to preserve the WaveRNN class name (instead of Wavernn)
|
||||||
if c.generator_model != 'WaveRNN':
|
if c.generator_model.lower() == 'wavernn':
|
||||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
MyModel = getattr(MyModel, 'WaveRNN')
|
||||||
else:
|
else:
|
||||||
MyModel = getattr(MyModel, c.generator_model)
|
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||||
if c.generator_model.lower() in 'wavernn':
|
if c.generator_model.lower() in 'wavernn':
|
||||||
model = MyModel(
|
model = MyModel(
|
||||||
rnn_dims=c.wavernn_model_params['rnn_dims'],
|
rnn_dims=c.wavernn_model_params['rnn_dims'],
|
||||||
|
|
Loading…
Reference in New Issue