load_checkpoint for hifigan and no_grad for inference

This commit is contained in:
Eren Gölge 2021-04-06 11:02:50 +02:00
parent de3a04f104
commit 241e968df1
1 changed file with 9 additions and 0 deletions

View File

@ -159,6 +159,7 @@ class HifiganGenerator(torch.nn.Module):
x = torch.tanh(x)
return x
@torch.no_grad()
def inference(self, c):
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
@ -173,3 +174,11 @@ class HifiganGenerator(torch.nn.Module):
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
    """Load model weights from a saved checkpoint.

    Args:
        config: Model configuration (unused; kept for a uniform loader interface
            across vocoder models).
        checkpoint_path: Path to a checkpoint file whose payload contains a
            ``'model'`` key holding the state dict.
        eval: When True, switch the module to eval mode and strip weight norm,
            preparing it for inference.

    Raises:
        RuntimeError: If the module is still in training mode after ``eval()``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoint files.
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    self.load_state_dict(state['model'])
    if eval:
        self.eval()
        # Explicit check instead of `assert`, which is silently stripped
        # under `python -O`.
        if self.training:
            raise RuntimeError("Model failed to switch to eval mode.")
        # Weight norm is only needed during training; removing it speeds up
        # and stabilizes inference.
        self.remove_weight_norm()