Fix loading the `amp` scaler from a checkpoint 🛠️

Eren Gölge 2021-06-23 11:08:34 +02:00
parent a7617d8ab6
commit fbba37e01e
1 changed file with 1 addition and 1 deletion


@@ -306,7 +306,7 @@ class Trainer:
             model.load_state_dict(checkpoint["model"])
             print(" > Restoring Optimizer...")
             optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
-            if "scaler" in checkpoint and self.use_amp_scaler:
+            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                 print(" > Restoring AMP Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
         except (KeyError, RuntimeError):
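
The added truthiness check guards against checkpoints saved without AMP enabled, which can store `"scaler": None`; without it, restoring such a checkpoint would pass `None` into the restore helper. A minimal sketch of the guarded restore, using PyTorch's `GradScaler.load_state_dict` in place of the trainer's `_restore_list_objs` helper and a hypothetical checkpoint dict:

```python
import torch

# Hypothetical checkpoint from a run that had AMP disabled:
# the "scaler" key exists but holds None.
checkpoint = {"model": {}, "optimizer": {}, "scaler": None}
use_amp_scaler = True

scaler = torch.cuda.amp.GradScaler()

# The extra `checkpoint["scaler"]` check skips the restore when the
# stored scaler state is None, instead of failing inside the loader.
if "scaler" in checkpoint and use_amp_scaler and checkpoint["scaler"]:
    print(" > Restoring AMP Scaler...")
    scaler.load_state_dict(checkpoint["scaler"])
```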