mirror of https://github.com/coqui-ai/TTS.git
add load_checkpoint func to tts models
This commit is contained in:
parent
5c87753e88
commit
1faf565e3a
|
@ -223,3 +223,11 @@ class GlowTts(nn.Module):
|
||||||
|
|
||||||
def store_inverse(self):
|
def store_inverse(self):
|
||||||
self.decoder.store_inverse()
|
self.decoder.store_inverse()
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
self.store_inverse()
|
||||||
|
assert not self.training
|
||||||
|
|
|
@ -190,3 +190,10 @@ class SpeedySpeech(nn.Module):
|
||||||
y_lengths = o_dr.sum(1)
|
y_lengths = o_dr.sum(1)
|
||||||
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||||
return o_de, attn
|
return o_de, attn
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
|
@ -121,6 +121,14 @@ class TacotronAbstract(ABC, nn.Module):
|
||||||
def inference(self):
|
def inference(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
self.decoder.set_r(state['r'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
# COMMON COMPUTE FUNCTIONS
|
# COMMON COMPUTE FUNCTIONS
|
||||||
#############################
|
#############################
|
||||||
|
|
Loading…
Reference in New Issue