diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 8aa84d34..bded4cd8 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -499,3 +499,10 @@ class WaveRNN(nn.Module): unfolded[start:end] += y[i] return unfolded + + def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training