mirror of https://github.com/coqui-ai/TTS.git
Fix restoring best_loss
Keep the default value if the model checkpoint has no `model_loss`
This commit is contained in:
parent c8bbcdfd07
commit c5d1dd9d1b
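In effect, the change guards the `model_loss` lookup so that a checkpoint saved without that key no longer raises a `KeyError`, and the trainer keeps its default best loss instead. A minimal standalone sketch of the same pattern; the `read_best_loss` helper, its `default` argument, and the use of plain `torch.load` in place of the trainer's `load_fsspec` are illustrative assumptions, not code from this commit:

import torch

def read_best_loss(checkpoint_path: str, default: float = float("inf")) -> float:
    """Return the best loss stored in a checkpoint, or the default when it has none."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Older checkpoints may not carry "model_loss"; keep the caller's default then.
    if "model_loss" in checkpoint:
        return checkpoint["model_loss"]
    return default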
@@ -78,7 +78,7 @@ class TrainingArgs(Coqpit):
     best_path: str = field(
         default="",
         metadata={
-            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
@@ -149,6 +149,7 @@ class Trainer:
         >>> trainer.fit()

     TODO:
+        - Wrap model for not calling .module in DDP.
         - Accumulate gradients b/w batches.
         - Deepspeed integration
         - Profiler integration.
@@ -331,7 +332,7 @@ class Trainer:
             print(" > Restoring Optimizer...")
             optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
             if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
-                print(" > Restoring AMP Scaler...")
+                print(" > Restoring Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
@@ -477,7 +478,7 @@ class Trainer:

         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
+            raise RuntimeError(f" > NaN loss detected - {loss_dict}")

         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -819,7 +820,7 @@ class Trainer:
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
-        if hasattr(self.model, "test_run"):
+        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
             if self.eval_loader is None:
                 self.eval_loader = self.get_eval_dataloader(
                     self.ap,
@@ -841,13 +842,20 @@ class Trainer:
             self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.dashboard_logger.test_figures(self.total_steps_done, figures)

-    def _fit(self) -> None:
-        """🏃 train -> evaluate -> test for the number of epochs."""
+    def _restore_best_loss(self):
+        """Restore the best loss from the args.best_path if provided else
+        from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
         if self.restore_step != 0 or self.args.best_path:
             print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
+            ch = load_fsspec(self.args.restore_path, map_location="cpu")
+            if "model_loss" in ch:
+                self.best_loss = ch["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

+    def _fit(self) -> None:
+        """🏃 train -> evaluate -> test for the number of epochs."""
+        self._restore_best_loss()
+
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):