mirror of https://github.com/coqui-ai/TTS.git
wavernn load_checkpoint function
This commit is contained in:
parent
50fee59a2c
commit
9addfabc43
|
@ -499,3 +499,10 @@ class WaveRNN(nn.Module):
|
||||||
unfolded[start:end] += y[i]
|
unfolded[start:end] += y[i]
|
||||||
|
|
||||||
return unfolded
|
return unfolded
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
Loading…
Reference in New Issue