mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #327 from reuben/server-fix-2
Server fix without breaking uWSGI
This commit is contained in:
commit
48e4baf434
|
@ -55,8 +55,11 @@ def tts():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if not config or not synthesizer:
|
|
||||||
args = create_argparser().parse_args()
|
args = create_argparser().parse_args()
|
||||||
|
|
||||||
|
# Setup synthesizer from CLI args if they're specified or no embedded model
|
||||||
|
# is present.
|
||||||
|
if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
|
||||||
synthesizer = Synthesizer(args)
|
synthesizer = Synthesizer(args)
|
||||||
|
|
||||||
app.run(debug=config.debug, host='0.0.0.0', port=config.port)
|
app.run(debug=config.debug, host='0.0.0.0', port=config.port)
|
||||||
|
|
|
@ -53,15 +53,14 @@ class Synthesizer(object):
|
||||||
num_speakers = 0
|
num_speakers = 0
|
||||||
self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
|
self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
|
||||||
# load model state
|
# load model state
|
||||||
map_location = None if use_cuda else torch.device('cpu')
|
cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
|
||||||
cp = torch.load(tts_checkpoint, map_location=map_location)
|
|
||||||
# load the model
|
# load the model
|
||||||
self.tts_model.load_state_dict(cp['model'])
|
self.tts_model.load_state_dict(cp['model'])
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
self.tts_model.eval()
|
self.tts_model.eval()
|
||||||
self.tts_model.decoder.max_decoder_steps = 3000
|
self.tts_model.decoder.max_decoder_steps = 3000
|
||||||
if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]:
|
if 'r' in cp:
|
||||||
self.tts_model.decoder.set_r(cp['r'])
|
self.tts_model.decoder.set_r(cp['r'])
|
||||||
|
|
||||||
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
|
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
|
||||||
|
|
Loading…
Reference in New Issue