From 1faf565e3ae2270e7c603368afa7e2234a5877f2 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 Jan 2021 02:10:56 +0000 Subject: [PATCH] add load_checkpoint func to tts models --- TTS/tts/models/glow_tts.py | 8 ++++++++ TTS/tts/models/speedy_speech.py | 7 +++++++ TTS/tts/models/tacotron_abstract.py | 8 ++++++++ 3 files changed, 23 insertions(+) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index b55ba1b1..c978e4fa 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -223,3 +223,11 @@ class GlowTts(nn.Module): def store_inverse(self): self.decoder.store_inverse() + + def load_checkpoint(self, config, checkpoint_path, eval=False): + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + self.store_inverse() + assert not self.training diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 2e7d0a5f..7f5c660e 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -190,3 +190,10 @@ class SpeedySpeech(nn.Module): y_lengths = o_dr.sum(1) o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) return o_de, attn + + def load_checkpoint(self, config, checkpoint_path, eval=False): + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 54c46be2..0a63b871 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -121,6 +121,14 @@ class TacotronAbstract(ABC, nn.Module): def inference(self): pass + def load_checkpoint(self, config, checkpoint_path, eval=False): + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + self.decoder.set_r(state['r']) + if eval: + self.eval() + assert not self.training + ############################# # COMMON COMPUTE FUNCTIONS #############################