diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 44b0edfa..1da726c6 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -159,6 +159,7 @@ class HifiganGenerator(torch.nn.Module): x = torch.tanh(x) return x + @torch.no_grad() def inference(self, c): c = c.to(self.conv_pre.weight.device) c = torch.nn.functional.pad( @@ -173,3 +174,11 @@ class HifiganGenerator(torch.nn.Module): l.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) + + def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm()