diff --git a/TTS/trainer.py b/TTS/trainer.py
index d3d66ab2..c02578d4 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -78,7 +78,7 @@ class TrainingArgs(Coqpit):
     best_path: str = field(
         default="",
         metadata={
-            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
@@ -149,6 +149,7 @@ class Trainer:
         >>> trainer.fit()
 
     TODO:
+        - Wrap the model to avoid .module calls in DDP.
         - Accumulate gradients b/w batches.
         - Deepspeed integration
         - Profiler integration.
@@ -331,7 +332,7 @@ class Trainer:
                 print(" > Restoring Optimizer...")
                 optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
             if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
-                print(" > Restoring AMP Scaler...")
+                print(" > Restoring Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
@@ -477,7 +478,7 @@ class Trainer:
 
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
+            raise RuntimeError(f" > NaN loss detected - {loss_dict}")
 
         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -819,7 +820,7 @@ class Trainer:
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
-        if hasattr(self.model, "test_run"):
+        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
             if self.eval_loader is None:
                 self.eval_loader = self.get_eval_dataloader(
                     self.ap,
@@ -841,13 +842,20 @@ class Trainer:
             self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.dashboard_logger.test_figures(self.total_steps_done, figures)
 
-    def _fit(self) -> None:
-        """🏃 train -> evaluate -> test for the number of epochs."""
+    def _restore_best_loss(self):
+        """Restore the best loss from `args.best_path` if provided, else from
+        the checkpoint (`args.restore_path` or `args.continue_path`) used to resume training."""
         if self.restore_step != 0 or self.args.best_path:
             print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
+            ch = load_fsspec(self.args.restore_path, map_location="cpu")
+            if "model_loss" in ch:
+                self.best_loss = ch["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")
 
+    def _fit(self) -> None:
+        """🏃 train -> evaluate -> test for the number of epochs."""
+        self._restore_best_loss()
+
         self.total_steps_done = self.restore_step
 
         for epoch in range(0, self.config.epochs):
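
A note on the `test_run` hunk and the new TODO item: `torch.nn.parallel.DistributedDataParallel` wraps the model, and its attribute lookup does not fall through to the wrapped instance, so model-defined hooks are only reachable as `self.model.module.test_run` when `num_gpus > 1`. Below is a minimal, self-contained sketch of that behavior; `MyModel` and `has_hook` are hypothetical names used for illustration only, not part of this patch.

```python
import torch.nn as nn


class MyModel(nn.Module):
    """Stand-in for a TTS model that defines its own `test_run` hook."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def test_run(self):
        return [], []  # figures, audios


def has_hook(model: nn.Module, name: str, num_gpus: int) -> bool:
    # Mirrors the patched condition: check the raw model first; under DDP the
    # wrapper hides custom attributes, so also look behind `.module`.
    return hasattr(model, name) or (num_gpus > 1 and hasattr(model.module, name))


model = MyModel()
print(has_hook(model, "test_run", num_gpus=1))  # True for the unwrapped model
# With an initialized process group, wrapping via
#   ddp = nn.parallel.DistributedDataParallel(model)
# gives hasattr(ddp, "test_run") == False but hasattr(ddp.module, "test_run") == True,
# which is why the patched check needs the second branch.
```

Wrapping the model once, as the new TODO item suggests, would centralize this branch instead of repeating it at every hook lookup.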
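The `_restore_best_loss` hunk also replaces an unconditional `["model_loss"]` lookup with a guarded read, presumably because a checkpoint saved without that key would otherwise raise `KeyError` on resume. A hedged sketch of the guarded pattern, using a toy dict rather than a real Coqui checkpoint:

```python
best_loss = float("inf")  # the trainer's default before any restore

# Toy checkpoint for illustration; the real one comes from load_fsspec(...).
checkpoint = {"model": {}, "step": 12500}  # note: no "model_loss" entry

if "model_loss" in checkpoint:  # the guard added by this patch
    best_loss = checkpoint["model_loss"]

print(f" > Starting with loaded last best loss {best_loss}.")  # keeps the default
```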