mirror of https://github.com/coqui-ai/TTS.git
add load_chekpoint to speaker encoder
This commit is contained in:
parent
1229ccbf07
commit
2da81f5bb6
|
@ -108,3 +108,11 @@ class SpeakerEncoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
|
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
|
||||||
return embed / num_iters
|
return embed / num_iters
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument, redefined-builtin
|
||||||
|
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
self.load_state_dict(state["model"])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
Loading…
Reference in New Issue