From f0635453256fd48091d26a474d6edfd300280429 Mon Sep 17 00:00:00 2001 From: Branislav Gerazov Date: Fri, 5 Feb 2021 13:26:33 +0100 Subject: [PATCH] improve robustness of defining wavernn in config file --- TTS/vocoder/utils/generic_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 05ceba6b..0d532063 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -71,10 +71,10 @@ def setup_generator(c): MyModel = importlib.import_module('TTS.vocoder.models.' + c.generator_model.lower()) # this is to preserve the WaveRNN class name (instead of Wavernn) - if c.generator_model != 'WaveRNN': - MyModel = getattr(MyModel, to_camel(c.generator_model)) + if c.generator_model.lower() == 'wavernn': + MyModel = getattr(MyModel, 'WaveRNN') else: - MyModel = getattr(MyModel, c.generator_model) + MyModel = getattr(MyModel, to_camel(c.generator_model)) if c.generator_model.lower() in 'wavernn': model = MyModel( rnn_dims=c.wavernn_model_params['rnn_dims'],