From 2da81f5bb6887ca587d8edb0e8d42d5914a6f361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 21 Apr 2021 13:09:04 +0200 Subject: [PATCH] add load_chekpoint to speaker encoder --- TTS/speaker_encoder/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/TTS/speaker_encoder/model.py b/TTS/speaker_encoder/model.py index 7a3dc09c..3d52382a 100644 --- a/TTS/speaker_encoder/model.py +++ b/TTS/speaker_encoder/model.py @@ -108,3 +108,11 @@ class SpeakerEncoder(nn.Module): else: embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) return embed / num_iters + + # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False): + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training