Only use embedded model files if they're not overriden by CLI flags

This commit is contained in:
Reuben Morais 2020-02-04 11:16:48 +01:00
parent 9d669d1024
commit 57e7c1de08
1 changed files with 20 additions and 15 deletions

View File

@ -24,20 +24,32 @@ def create_argparser():
return parser
config = None
synthesizer = None
embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar')
config_file = os.path.join(embedded_model_folder, 'config.json')
if os.path.isfile(checkpoint_file) and os.path.isfile(config_file):
# Use default config with embedded model files
config = create_argparser().parse_args([])
config.tts_checkpoint = checkpoint_file
config.tts_config = config_file
synthesizer = Synthesizer(config)
# Default options with embedded model files
if os.path.isfile(checkpoint_file):
default_tts_checkpoint = checkpoint_file
else:
default_tts_checkpoint = None
if os.path.isfile(config_file):
default_tts_config = config_file
else:
default_tts_config = None
args = create_argparser().parse_args()
# If these were not specified in the CLI args, use default values
if not args.tts_checkpoint:
args.tts_checkpoint = default_tts_checkpoint
if not args.tts_config:
args.tts_config = default_tts_config
synthesizer = Synthesizer(args)
app = Flask(__name__)
@ -55,11 +67,4 @@ def tts():
if __name__ == '__main__':
args = create_argparser().parse_args()
# Setup synthesizer from CLI args if they're specified or no embedded model
# is present.
if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
synthesizer = Synthesizer(args)
app.run(debug=config.debug, host='0.0.0.0', port=config.port)
app.run(debug=args.debug, host='0.0.0.0', port=args.port)