wavernn load_checkpoint function

This commit is contained in:
Eren Gölge 2021-01-21 15:31:13 +01:00
parent 50fee59a2c
commit 9addfabc43
1 changed files with 7 additions and 0 deletions

View File

@ -499,3 +499,10 @@ class WaveRNN(nn.Module):
unfolded[start:end] += y[i]
return unfolded
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training