mirror of https://github.com/coqui-ai/TTS.git
Only use embedded model files if they're not overriden by CLI flags
This commit is contained in:
parent
9d669d1024
commit
57e7c1de08
|
@ -24,20 +24,32 @@ def create_argparser():
|
|||
return parser
|
||||
|
||||
|
||||
config = None
|
||||
synthesizer = None
|
||||
|
||||
embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
|
||||
checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar')
|
||||
config_file = os.path.join(embedded_model_folder, 'config.json')
|
||||
|
||||
if os.path.isfile(checkpoint_file) and os.path.isfile(config_file):
|
||||
# Use default config with embedded model files
|
||||
config = create_argparser().parse_args([])
|
||||
config.tts_checkpoint = checkpoint_file
|
||||
config.tts_config = config_file
|
||||
synthesizer = Synthesizer(config)
|
||||
# Default options with embedded model files
|
||||
if os.path.isfile(checkpoint_file):
|
||||
default_tts_checkpoint = checkpoint_file
|
||||
else:
|
||||
default_tts_checkpoint = None
|
||||
|
||||
if os.path.isfile(config_file):
|
||||
default_tts_config = config_file
|
||||
else:
|
||||
default_tts_config = None
|
||||
|
||||
args = create_argparser().parse_args()
|
||||
|
||||
# If these were not specified in the CLI args, use default values
|
||||
if not args.tts_checkpoint:
|
||||
args.tts_checkpoint = default_tts_checkpoint
|
||||
if not args.tts_config:
|
||||
args.tts_config = default_tts_config
|
||||
|
||||
synthesizer = Synthesizer(args)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
@ -55,11 +67,4 @@ def tts():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = create_argparser().parse_args()
|
||||
|
||||
# Setup synthesizer from CLI args if they're specified or no embedded model
|
||||
# is present.
|
||||
if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
|
||||
synthesizer = Synthesizer(args)
|
||||
|
||||
app.run(debug=config.debug, host='0.0.0.0', port=config.port)
|
||||
app.run(debug=args.debug, host='0.0.0.0', port=args.port)
|
||||
|
|
Loading…
Reference in New Issue