diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 05ceba6b..0d532063 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -71,10 +71,10 @@ def setup_generator(c): MyModel = importlib.import_module('TTS.vocoder.models.' + c.generator_model.lower()) # this is to preserve the WaveRNN class name (instead of Wavernn) - if c.generator_model != 'WaveRNN': - MyModel = getattr(MyModel, to_camel(c.generator_model)) + if c.generator_model.lower() == 'wavernn': + MyModel = getattr(MyModel, 'WaveRNN') else: - MyModel = getattr(MyModel, c.generator_model) + MyModel = getattr(MyModel, to_camel(c.generator_model)) if c.generator_model.lower() in 'wavernn': model = MyModel( rnn_dims=c.wavernn_model_params['rnn_dims'],