From e2e92b63d59279e1ad38687660f4bd9f5daf7c2c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 10 Dec 2019 11:21:55 +0100 Subject: [PATCH] load model checkpoint on cpu, set 'r' for all models with gradual training enabled for all --- server/synthesizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/synthesizer.py b/server/synthesizer.py index 14de9411..d8852a3e 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -53,15 +53,14 @@ class Synthesizer(object): num_speakers = 0 self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state - map_location = None if use_cuda else torch.device('cpu') - cp = torch.load(tts_checkpoint, map_location=map_location) + cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model self.tts_model.load_state_dict(cp['model']) if use_cuda: self.tts_model.cuda() self.tts_model.eval() self.tts_model.decoder.max_decoder_steps = 3000 - if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]: + if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):