Merge pull request #327 from reuben/server-fix-2

Server fix without breaking uWSGI
Eren Gölge, 2019-12-13 09:01:59 +01:00 (committed by GitHub)
commit 48e4baf434
2 changed files with 7 additions and 5 deletions


@@ -55,8 +55,11 @@ def tts():
 if __name__ == '__main__':
-    if not config or not synthesizer:
-        args = create_argparser().parse_args()
+    args = create_argparser().parse_args()
+
+    # Setup synthesizer from CLI args if they're specified or no embedded model
+    # is present.
+    if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
         synthesizer = Synthesizer(args)
     app.run(debug=config.debug, host='0.0.0.0', port=config.port)
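
What the hunk does: a model embedded with the package is set up at module import time, which is all a uWSGI worker ever executes (the __main__ block is skipped there), while running the script directly can now override that synthesizer whenever --tts_checkpoint / --tts_config are passed or no embedded model was found. A self-contained sketch of the pattern, using illustrative stand-in names (DummySynthesizer, embedded_config) rather than the project's actual objects:

import argparse

class DummySynthesizer:
    """Illustrative stand-in for the real Synthesizer class."""
    def __init__(self, source):
        self.source = source

# Import-time setup: under uWSGI only this part runs.
embedded_config = {"debug": False}          # pretend an embedded model/config was found
synthesizer = DummySynthesizer("embedded model")

def create_argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--tts_checkpoint', default=None)
    parser.add_argument('--tts_config', default=None)
    return parser

if __name__ == '__main__':
    args = create_argparser().parse_args()
    # Rebuild the synthesizer only when CLI args are given or no embedded
    # model/config was found at import time.
    if not embedded_config or not synthesizer or args.tts_checkpoint or args.tts_config:
        synthesizer = DummySynthesizer("CLI args")
    print("serving with:", synthesizer.source)

Run under uWSGI, only the import-time branch executes; run directly, the CLI arguments decide whether the embedded synthesizer gets replaced.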


@@ -53,15 +53,14 @@ class Synthesizer(object):
             num_speakers = 0
         self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
         # load model state
-        map_location = None if use_cuda else torch.device('cpu')
-        cp = torch.load(tts_checkpoint, map_location=map_location)
+        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
         # load the model
         self.tts_model.load_state_dict(cp['model'])
         if use_cuda:
             self.tts_model.cuda()
         self.tts_model.eval()
         self.tts_model.decoder.max_decoder_steps = 3000
-        if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]:
+        if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])

     def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
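
The second hunk pins checkpoint deserialization to the CPU and moves the model to the GPU afterwards via .cuda(), so loading no longer depends on the device the checkpoint was saved from. A minimal sketch of that load pattern with a toy module (checkpoint_demo.pth and nn.Linear are stand-ins, not the project's Tacotron setup):

import torch
import torch.nn as nn

# Toy model standing in for the TTS model.
model = nn.Linear(4, 2)
torch.save({'model': model.state_dict(), 'r': 2}, 'checkpoint_demo.pth')

use_cuda = torch.cuda.is_available()
# Deserialize onto the CPU regardless of where the checkpoint was saved,
# then move the module to the GPU only when CUDA is actually requested.
cp = torch.load('checkpoint_demo.pth', map_location=torch.device('cpu'))
model.load_state_dict(cp['model'])
if use_cuda:
    model.cuda()
model.eval()

# Mirrors the reduction-factor handling in the diff.
if 'r' in cp:
    print("checkpoint 'r':", cp['r'])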