set eval mode whe nloading models

This commit is contained in:
root 2021-01-20 02:14:18 +00:00
parent 5bd7238153
commit 1bc8fbbd3c
2 changed files with 6 additions and 2 deletions

View File

@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):
"""Load ```TTS.tts.models``` checkpoints.
Args:
@ -33,6 +33,8 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
if hasattr(model.decoder, 'r'):
model.decoder.set_r(state['r'])
print(" > Model r: ", state['r'])
if eval:
model.eval()
return model, state

View File

@ -6,7 +6,7 @@ import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, use_cuda=False):
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
except ModuleNotFoundError:
@ -15,6 +15,8 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
model.load_state_dict(state['model'])
if use_cuda:
model.cuda()
if eval:
model.eval()
return model, state