diff --git a/TTS/trainer.py b/TTS/trainer.py index f628d9a4..d5aec1c9 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -306,7 +306,7 @@ class Trainer: model.load_state_dict(checkpoint["model"]) print(" > Restoring Optimizer...") optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) - if "scaler" in checkpoint and self.use_amp_scaler: + if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: print(" > Restoring AMP Scaler...") scaler = _restore_list_objs(checkpoint["scaler"], scaler) except (KeyError, RuntimeError):