Fix loading the `amp` scaler from a checkpoint 🛠️

This commit is contained in:
Eren Gölge 2021-06-23 11:08:34 +02:00
parent a7617d8ab6
commit fbba37e01e
1 changed files with 1 additions and 1 deletions

View File

@ -306,7 +306,7 @@ class Trainer:
model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...")
optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
if "scaler" in checkpoint and self.use_amp_scaler:
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring AMP Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler)
except (KeyError, RuntimeError):