From 1bc8fbbd3cb95941e15c5006f78f0f985cc48ce2 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 Jan 2021 02:14:18 +0000 Subject: [PATCH] set eval mode whe nloading models --- TTS/tts/utils/io.py | 4 +++- TTS/vocoder/utils/io.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index 830529a3..63e04283 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler -def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): +def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False): """Load ```TTS.tts.models``` checkpoints. Args: @@ -33,6 +33,8 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): if hasattr(model.decoder, 'r'): model.decoder.set_r(state['r']) print(" > Model r: ", state['r']) + if eval: + model.eval() return model, state diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py index c33d2cb9..5c42dfca 100644 --- a/TTS/vocoder/utils/io.py +++ b/TTS/vocoder/utils/io.py @@ -6,7 +6,7 @@ import pickle as pickle_tts from TTS.utils.io import RenamingUnpickler -def load_checkpoint(model, checkpoint_path, use_cuda=False): +def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): try: state = torch.load(checkpoint_path, map_location=torch.device('cpu')) except ModuleNotFoundError: @@ -15,6 +15,8 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False): model.load_state_dict(state['model']) if use_cuda: model.cuda() + if eval: + model.eval() return model, state