From fbba37e01eb25042de78bf706ba8dea2251bd92c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 23 Jun 2021 11:08:34 +0200 Subject: [PATCH] =?UTF-8?q?Fix=20loading=20the=20`amp`=20scaler=20from=20a?= =?UTF-8?q?=20checkpoint=20=F0=9F=9B=A0=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index f628d9a4..d5aec1c9 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -306,7 +306,7 @@ class Trainer: model.load_state_dict(checkpoint["model"]) print(" > Restoring Optimizer...") optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) - if "scaler" in checkpoint and self.use_amp_scaler: + if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: print(" > Restoring AMP Scaler...") scaler = _restore_list_objs(checkpoint["scaler"], scaler) except (KeyError, RuntimeError):