Fix restoring best_loss

Keep the default value if the model checkpoint has no `model_loss` key.
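In other words, `best_loss` keeps its initial value (typically `float("inf")`) when the checkpoint was written before `model_loss` was saved, instead of raising a `KeyError`. A minimal sketch of the pattern outside the Trainer, with a hypothetical `restore_best_loss` helper and `torch.load` standing in for `load_fsspec`:

    import torch

    def restore_best_loss(checkpoint_path: str, default: float = float("inf")) -> float:
        # Load the checkpoint dict on CPU; the Trainer uses load_fsspec() for this.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Older checkpoints may predate `model_loss`; keep the default in that case.
        if "model_loss" in checkpoint:
            return checkpoint["model_loss"]
        return default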
Eren Gölge 2021-08-17 11:26:24 +00:00
parent c8bbcdfd07
commit c5d1dd9d1b
1 changed file with 15 additions and 7 deletions


@@ -78,7 +78,7 @@ class TrainingArgs(Coqpit):
     best_path: str = field(
         default="",
         metadata={
-            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
@@ -149,6 +149,7 @@ class Trainer:
            >>> trainer.fit()
    TODO:
+        - Wrap model for not calling .module in DDP.
        - Accumulate gradients b/w batches.
        - Deepspeed integration
        - Profiler integration.
@ -331,7 +332,7 @@ class Trainer:
print(" > Restoring Optimizer...") print(" > Restoring Optimizer...")
optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]: if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring AMP Scaler...") print(" > Restoring Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler) scaler = _restore_list_objs(checkpoint["scaler"], scaler)
except (KeyError, RuntimeError): except (KeyError, RuntimeError):
print(" > Partial model initialization...") print(" > Partial model initialization...")
@@ -477,7 +478,7 @@ class Trainer:
        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
+            raise RuntimeError(f" > NaN loss detected - {loss_dict}")
        # set gradient clipping threshold
        if "grad_clip" in config and config.grad_clip is not None:
@@ -819,7 +820,7 @@ class Trainer:
    def test_run(self) -> None:
        """Run test and log the results. Test run must be defined by the model.
        Model must return figures and audios to be logged by the Tensorboard."""
-        if hasattr(self.model, "test_run"):
+        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
            if self.eval_loader is None:
                self.eval_loader = self.get_eval_dataloader(
                    self.ap,
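The extra branch is needed because `DistributedDataParallel` wraps the model, so user-defined methods such as `test_run` live on `self.model.module` rather than on the wrapper itself (the TODO added above notes the plan to hide this). The common unwrap pattern, sketched as a hypothetical helper:

    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        # DataParallel and DistributedDataParallel expose the wrapped model
        # as `.module`; unwrap so lookups reach methods like `test_run`.
        return model.module if hasattr(model, "module") else model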
@@ -841,13 +842,20 @@ class Trainer:
            self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
            self.dashboard_logger.test_figures(self.total_steps_done, figures)

-    def _fit(self) -> None:
-        """🏃 train -> evaluate -> test for the number of epochs."""
+    def _restore_best_loss(self):
+        """Restore the best loss from `args.best_path` if provided, else from the model
+        (`args.restore_path` or `args.continue_path`) used for resuming the training."""
        if self.restore_step != 0 or self.args.best_path:
            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
+            ch = load_fsspec(self.args.restore_path, map_location="cpu")
+            if "model_loss" in ch:
+                self.best_loss = ch["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")
+
+    def _fit(self) -> None:
+        """🏃 train -> evaluate -> test for the number of epochs."""
+        self._restore_best_loss()
        self.total_steps_done = self.restore_step
        for epoch in range(0, self.config.epochs):
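The motivation, per the commit message: not every checkpoint carries a `model_loss` entry, and the old unguarded `["model_loss"]` lookup raised a `KeyError` on such files. Illustrative payloads (hypothetical; real Trainer checkpoints carry more fields, e.g. optimizer and scaler state):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model

    # A best-model checkpoint records the loss it was selected by:
    torch.save({"model": model.state_dict(), "model_loss": 0.42}, "best_model.pth")

    # A checkpoint without the key; after this fix, restoring from it
    # leaves best_loss at its default instead of crashing.
    torch.save({"model": model.state_dict()}, "checkpoint.pth")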