mirror of https://github.com/coqui-ai/TTS.git
set eval mode whe nloading models
This commit is contained in:
parent
5bd7238153
commit
1bc8fbbd3c
|
@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler
|
|||
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
||||
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):
|
||||
"""Load ```TTS.tts.models``` checkpoints.
|
||||
|
||||
Args:
|
||||
|
@ -33,6 +33,8 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
|||
if hasattr(model.decoder, 'r'):
|
||||
model.decoder.set_r(state['r'])
|
||||
print(" > Model r: ", state['r'])
|
||||
if eval:
|
||||
model.eval()
|
||||
return model, state
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import pickle as pickle_tts
|
|||
from TTS.utils.io import RenamingUnpickler
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
||||
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
|
||||
try:
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
except ModuleNotFoundError:
|
||||
|
@ -15,6 +15,8 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
|||
model.load_state_dict(state['model'])
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
if eval:
|
||||
model.eval()
|
||||
return model, state
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue