add load_checkpoint func to tts models

This commit is contained in:
root 2021-01-20 02:10:56 +00:00
parent 5c87753e88
commit 1faf565e3a
3 changed files with 23 additions and 0 deletions

View File

@ -223,3 +223,11 @@ class GlowTts(nn.Module):
def store_inverse(self):
self.decoder.store_inverse()
def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
self.store_inverse()
assert not self.training

View File

@ -190,3 +190,10 @@ class SpeedySpeech(nn.Module):
y_lengths = o_dr.sum(1)
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn
def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training

View File

@ -121,6 +121,14 @@ class TacotronAbstract(ABC, nn.Module):
def inference(self):
pass
def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
self.decoder.set_r(state['r'])
if eval:
self.eval()
assert not self.training
#############################
# COMMON COMPUTE FUNCTIONS
#############################