From 9addfabc430ba6956d4f796d08fc2bd6fd10eac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 21 Jan 2021 15:31:13 +0100 Subject: [PATCH] wavernn load_checkpoint function --- TTS/vocoder/models/wavernn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 8aa84d34..bded4cd8 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -499,3 +499,10 @@ class WaveRNN(nn.Module): unfolded[start:end] += y[i] return unfolded + + def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training