mirror of https://github.com/coqui-ai/TTS.git
Fix callbacks against multi-gpu training
parent 9a145c9b88
commit 512ada7548
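The change inverts the callback wiring: TrainerCallback no longer stores the trainer at construction time, each hook receives it as an argument, and the model hooks look through the `module` attribute that torch.nn.parallel.DistributedDataParallel adds when it wraps a model. A minimal standalone sketch (not from this commit) of why that extra check is needed:

import torch

class MyModel(torch.nn.Module):
    # A model that defines one of the trainer hooks.
    def on_train_step_start(self, trainer):
        print("hook called")

model = MyModel()
print(hasattr(model, "on_train_step_start"))  # True

# After wrapping for multi-GPU training (DDP needs an initialized process
# group, so this part is shown as it would run inside a worker process):
#   ddp = torch.nn.parallel.DistributedDataParallel(model)
#   hasattr(ddp, "on_train_step_start")         # False: DDP does not forward
#                                               # arbitrary attributes
#   hasattr(ddp.module, "on_train_step_start")  # True: the wrapped model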
@@ -284,8 +284,8 @@ class Trainer:
         self.optimizer = self.get_optimizer(self.model, self.config)
 
         # CALLBACK
-        self.callbacks = TrainerCallback(self)
-        self.callbacks.on_init_start()
+        self.callbacks = TrainerCallback()
+        self.callbacks.on_init_start(self)
 
         # init AMP
         if self.use_amp_scaler:
@@ -324,7 +324,7 @@ class Trainer:
         num_params = count_parameters(self.model)
         print("\n > Model has {} parameters".format(num_params))
 
-        self.callbacks.on_init_end()
+        self.callbacks.on_init_end(self)
 
     @staticmethod
     def parse_argv(args: Union[Coqpit, List]):
@@ -677,7 +677,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
-        self.callbacks.on_train_step_start()
+        self.callbacks.on_train_step_start(self)
         # format data
         batch = self.format_batch(batch)
         loader_time = time.time() - loader_start_time
@@ -792,7 +792,7 @@ class Trainer:
                 self.dashboard_logger.flush()
 
         self.total_steps_done += 1
-        self.callbacks.on_train_step_end()
+        self.callbacks.on_train_step_end(self)
         return outputs, loss_dict
 
     def train_epoch(self) -> None:
@@ -983,7 +983,7 @@ class Trainer:
             if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
-            self.callbacks.on_epoch_start()
+            self.callbacks.on_epoch_start(self)
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
@@ -999,7 +999,7 @@ class Trainer:
             )
             if self.args.rank in [None, 0]:
                 self.save_best_model()
-            self.callbacks.on_epoch_end()
+            self.callbacks.on_epoch_end(self)
 
     def fit(self) -> None:
         """Where the ✨️magic✨️ happens..."""
@@ -1008,7 +1008,7 @@ class Trainer:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
-            self.callbacks.on_keyboard_interrupt()
+            self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # clear the DDP processes
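The second hunk rewrites TrainerCallback itself. Every hook now follows the same dispatch shape: unwrap model.module when it exists, then fire the hook on the model, the criterion, and the optimizer wherever it is defined. Condensed into a single hypothetical helper (the commit spells the pattern out per hook instead):

def _call_hook(obj, name, trainer):
    # Prefer the DDP-wrapped model when a "module" attribute exists. Note
    # that, mirroring the commit's logic, a hook defined on the wrapper
    # itself is skipped in that case.
    target = obj.module if hasattr(obj, "module") else obj
    if hasattr(target, name):
        getattr(target, name)(trainer)

# Each hook body would then reduce to three calls, e.g. for on_epoch_start:
#   _call_hook(trainer.model, "on_epoch_start", trainer)
#   _call_hook(trainer.criterion, "on_epoch_start", trainer)
#   _call_hook(trainer.optimizer, "on_epoch_start", trainer)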
@@ -1,75 +1,101 @@
 class TrainerCallback:
-    def __init__(self, trainer):
+    def __init__(self):
         super().__init__()
-        self.trainer = trainer
 
-    def on_init_start(self) -> None:
-        if hasattr(self.trainer.model, "on_init_start"):
-            self.trainer.model.on_init_start(self.trainer)
+    def on_init_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_start"):
+                trainer.model.module.on_init_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_start"):
+                trainer.model.on_init_start(trainer)
 
-        if hasattr(self.trainer.criterion, "on_init_start"):
-            self.trainer.criterion.on_init_start(self.trainer)
+        if hasattr(trainer.criterion, "on_init_start"):
+            trainer.criterion.on_init_start(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_init_start"):
-            self.trainer.optimizer.on_init_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_start"):
+            trainer.optimizer.on_init_start(trainer)
 
-    def on_init_end(self) -> None:
-        if hasattr(self.trainer.model, "on_init_end"):
-            self.trainer.model.on_init_end(self.trainer)
+    def on_init_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_end"):
+                trainer.model.module.on_init_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_end"):
+                trainer.model.on_init_end(trainer)
 
-        if hasattr(self.trainer.criterion, "on_init_end"):
-            self.trainer.criterion.on_init_end(self.trainer)
+        if hasattr(trainer.criterion, "on_init_end"):
+            trainer.criterion.on_init_end(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_init_end"):
-            self.trainer.optimizer.on_init_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_end"):
+            trainer.optimizer.on_init_end(trainer)
 
-    def on_epoch_start(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_start"):
-            self.trainer.model.on_epoch_start(self.trainer)
+    def on_epoch_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_start"):
+                trainer.model.module.on_epoch_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_start"):
+                trainer.model.on_epoch_start(trainer)
 
-        if hasattr(self.trainer.criterion, "on_epoch_start"):
-            self.trainer.criterion.on_epoch_start(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_start"):
+            trainer.criterion.on_epoch_start(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_epoch_start"):
-            self.trainer.optimizer.on_epoch_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_start"):
+            trainer.optimizer.on_epoch_start(trainer)
 
-    def on_epoch_end(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_end"):
-            self.trainer.model.on_epoch_end(self.trainer)
+    def on_epoch_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_end"):
+                trainer.model.module.on_epoch_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_end"):
+                trainer.model.on_epoch_end(trainer)
 
-        if hasattr(self.trainer.criterion, "on_epoch_end"):
-            self.trainer.criterion.on_epoch_end(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_end"):
+            trainer.criterion.on_epoch_end(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_epoch_end"):
-            self.trainer.optimizer.on_epoch_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_end"):
+            trainer.optimizer.on_epoch_end(trainer)
 
-    def on_train_step_start(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_start"):
-            self.trainer.model.on_train_step_start(self.trainer)
+    def on_train_step_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_start"):
+                trainer.model.module.on_train_step_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_start"):
+                trainer.model.on_train_step_start(trainer)
 
-        if hasattr(self.trainer.criterion, "on_train_step_start"):
-            self.trainer.criterion.on_train_step_start(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_start"):
+            trainer.criterion.on_train_step_start(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_train_step_start"):
-            self.trainer.optimizer.on_train_step_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_start"):
+            trainer.optimizer.on_train_step_start(trainer)
 
-    def on_train_step_end(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_end"):
-            self.trainer.model.on_train_step_end(self.trainer)
+    def on_train_step_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_end"):
+                trainer.model.module.on_train_step_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_end"):
+                trainer.model.on_train_step_end(trainer)
 
-        if hasattr(self.trainer.criterion, "on_train_step_end"):
-            self.trainer.criterion.on_train_step_end(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_end"):
+            trainer.criterion.on_train_step_end(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_train_step_end"):
-            self.trainer.optimizer.on_train_step_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_end"):
+            trainer.optimizer.on_train_step_end(trainer)
 
-    def on_keyboard_interrupt(self) -> None:
-        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
-            self.trainer.model.on_keyboard_interrupt(self.trainer)
+    def on_keyboard_interrupt(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_keyboard_interrupt"):
+                trainer.model.module.on_keyboard_interrupt(trainer)
+        else:
+            if hasattr(trainer.model, "on_keyboard_interrupt"):
+                trainer.model.on_keyboard_interrupt(trainer)
 
-        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
-            self.trainer.criterion.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.criterion, "on_keyboard_interrupt"):
+            trainer.criterion.on_keyboard_interrupt(trainer)
 
-        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
-            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
+            trainer.optimizer.on_keyboard_interrupt(trainer)
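With the new signatures, a model opts into a callback simply by defining the hook with a trainer parameter; no registration is required, and the same class works wrapped or unwrapped. A hedged usage sketch (the model and the annealing schedule are made up; trainer.epochs_done is the counter set in the Trainer hunk above):

import torch

class MyTTSModel(torch.nn.Module):
    def on_epoch_start(self, trainer):
        # Hypothetical schedule: anneal a loss weight from the trainer
        # state that every hook now receives explicitly.
        self.loss_alpha = max(0.0, 1.0 - trainer.epochs_done / 100)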