add load_chekpoint to speaker encoder

This commit is contained in:
Eren Gölge 2021-04-21 13:09:04 +02:00
parent 1229ccbf07
commit 2da81f5bb6
1 changed files with 8 additions and 0 deletions

View File

@ -108,3 +108,11 @@ class SpeakerEncoder(nn.Module):
else:
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
return embed / num_iters
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training