mirror of https://github.com/coqui-ai/TTS.git
set eval mode whe nloading models
This commit is contained in:
parent
5bd7238153
commit
1bc8fbbd3c
|
@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):
|
||||||
"""Load ```TTS.tts.models``` checkpoints.
|
"""Load ```TTS.tts.models``` checkpoints.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -33,6 +33,8 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
||||||
if hasattr(model.decoder, 'r'):
|
if hasattr(model.decoder, 'r'):
|
||||||
model.decoder.set_r(state['r'])
|
model.decoder.set_r(state['r'])
|
||||||
print(" > Model r: ", state['r'])
|
print(" > Model r: ", state['r'])
|
||||||
|
if eval:
|
||||||
|
model.eval()
|
||||||
return model, state
|
return model, state
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import pickle as pickle_tts
|
||||||
from TTS.utils.io import RenamingUnpickler
|
from TTS.utils.io import RenamingUnpickler
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
|
||||||
try:
|
try:
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
@ -15,6 +15,8 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
||||||
model.load_state_dict(state['model'])
|
model.load_state_dict(state['model'])
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
if eval:
|
||||||
|
model.eval()
|
||||||
return model, state
|
return model, state
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue