From e2e92b63d59279e1ad38687660f4bd9f5daf7c2c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 10 Dec 2019 11:21:55 +0100 Subject: [PATCH 1/2] load model checkpoint on cpu, set 'r' for all models with gradual training enabled for all --- server/synthesizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/synthesizer.py b/server/synthesizer.py index 14de9411..d8852a3e 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -53,15 +53,14 @@ class Synthesizer(object): num_speakers = 0 self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state - map_location = None if use_cuda else torch.device('cpu') - cp = torch.load(tts_checkpoint, map_location=map_location) + cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model self.tts_model.load_state_dict(cp['model']) if use_cuda: self.tts_model.cuda() self.tts_model.eval() self.tts_model.decoder.max_decoder_steps = 3000 - if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]: + if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): From 856e87a6a88c823dc34f28fd5e7fddaeafbc0ba6 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Wed, 11 Dec 2019 14:31:59 +0100 Subject: [PATCH 2/2] Override checkpoint and config with CLI args if present --- server/server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/server.py b/server/server.py index 2831c754..d40e2427 100644 --- a/server/server.py +++ b/server/server.py @@ -55,8 +55,11 @@ def tts(): if __name__ == '__main__': - if not config or not synthesizer: - args = create_argparser().parse_args() + args = create_argparser().parse_args() + + # Setup synthesizer from CLI args if they're specified or no embedded model + # is present. + if not config or not synthesizer or args.tts_checkpoint or args.tts_config: synthesizer = Synthesizer(args) app.run(debug=config.debug, host='0.0.0.0', port=config.port)