mirror of https://github.com/coqui-ai/TTS.git
load_checkpoint for hifigan and no_grad for inference
This commit is contained in:
parent
de3a04f104
commit
241e968df1
|
@ -159,6 +159,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
|
@ -173,3 +174,11 @@ class HifiganGenerator(torch.nn.Module):
|
|||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
|
Loading…
Reference in New Issue