From 57e7c1de08c527dbd97dddfe9804d7de969739cd Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 4 Feb 2020 11:16:48 +0100 Subject: [PATCH] Only use embedded model files if they're not overriden by CLI flags --- server/server.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/server/server.py b/server/server.py index d40e2427..3be66f9e 100644 --- a/server/server.py +++ b/server/server.py @@ -24,20 +24,32 @@ def create_argparser(): return parser -config = None synthesizer = None embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') config_file = os.path.join(embedded_model_folder, 'config.json') -if os.path.isfile(checkpoint_file) and os.path.isfile(config_file): - # Use default config with embedded model files - config = create_argparser().parse_args([]) - config.tts_checkpoint = checkpoint_file - config.tts_config = config_file - synthesizer = Synthesizer(config) +# Default options with embedded model files +if os.path.isfile(checkpoint_file): + default_tts_checkpoint = checkpoint_file +else: + default_tts_checkpoint = None +if os.path.isfile(config_file): + default_tts_config = config_file +else: + default_tts_config = None + +args = create_argparser().parse_args() + +# If these were not specified in the CLI args, use default values +if not args.tts_checkpoint: + args.tts_checkpoint = default_tts_checkpoint +if not args.tts_config: + args.tts_config = default_tts_config + +synthesizer = Synthesizer(args) app = Flask(__name__) @@ -55,11 +67,4 @@ def tts(): if __name__ == '__main__': - args = create_argparser().parse_args() - - # Setup synthesizer from CLI args if they're specified or no embedded model - # is present. - if not config or not synthesizer or args.tts_checkpoint or args.tts_config: - synthesizer = Synthesizer(args) - - app.run(debug=config.debug, host='0.0.0.0', port=config.port) + app.run(debug=args.debug, host='0.0.0.0', port=args.port)