mirror of https://github.com/coqui-ai/TTS.git
Fix loading the `amp` scaler from a checkpoint 🛠️
commit fbba37e01e (parent a7617d8ab6)
@@ -306,7 +306,7 @@ class Trainer:
             model.load_state_dict(checkpoint["model"])
             print(" > Restoring Optimizer...")
             optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
-            if "scaler" in checkpoint and self.use_amp_scaler:
+            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                 print(" > Restoring AMP Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
         except (KeyError, RuntimeError):
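The added `and checkpoint["scaler"]` condition guards against checkpoints that carry a `"scaler"` entry set to `None` (e.g. a run saved without mixed precision), which would otherwise be passed to the restore helper and fail. Below is a minimal sketch of that failure mode and the guard; it is an assumption for illustration, using a plain `GradScaler.load_state_dict` call instead of the trainer's `_restore_list_objs` helper, and the checkpoint dict is hypothetical.

    import torch
    from torch.cuda.amp import GradScaler

    # Hypothetical checkpoint written by a run that did not use AMP:
    # the "scaler" key exists but holds no state.
    checkpoint = {"model": {}, "optimizer": {}, "scaler": None}

    use_amp_scaler = True
    scaler = GradScaler(enabled=torch.cuda.is_available())

    # Without the truthiness check, checkpoint["scaler"] (None) would be
    # handed to the restore call and raise; with it, the restore is skipped.
    if "scaler" in checkpoint and use_amp_scaler and checkpoint["scaler"]:
        print(" > Restoring AMP Scaler...")
        scaler.load_state_dict(checkpoint["scaler"])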