Fix callbacks against multi-gpu training

This commit is contained in:
Eren Gölge 2021-12-01 10:32:14 +00:00
parent 9a145c9b88
commit 512ada7548
2 changed files with 86 additions and 60 deletions

View File

@ -284,8 +284,8 @@ class Trainer:
self.optimizer = self.get_optimizer(self.model, self.config)
# CALLBACK
self.callbacks = TrainerCallback(self)
self.callbacks.on_init_start()
self.callbacks = TrainerCallback()
self.callbacks.on_init_start(self)
# init AMP
if self.use_amp_scaler:
@ -324,7 +324,7 @@ class Trainer:
num_params = count_parameters(self.model)
print("\n > Model has {} parameters".format(num_params))
self.callbacks.on_init_end()
self.callbacks.on_init_end(self)
@staticmethod
def parse_argv(args: Union[Coqpit, List]):
@ -677,7 +677,7 @@ class Trainer:
Returns:
Tuple[Dict, Dict]: Model outputs and losses.
"""
self.callbacks.on_train_step_start()
self.callbacks.on_train_step_start(self)
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
@ -792,7 +792,7 @@ class Trainer:
self.dashboard_logger.flush()
self.total_steps_done += 1
self.callbacks.on_train_step_end()
self.callbacks.on_train_step_end(self)
return outputs, loss_dict
def train_epoch(self) -> None:
@ -983,7 +983,7 @@ class Trainer:
if self.num_gpus > 1:
# let all processes sync up before starting with a new epoch of training
dist.barrier()
self.callbacks.on_epoch_start()
self.callbacks.on_epoch_start(self)
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
@ -999,7 +999,7 @@ class Trainer:
)
if self.args.rank in [None, 0]:
self.save_best_model()
self.callbacks.on_epoch_end()
self.callbacks.on_epoch_end(self)
def fit(self) -> None:
"""Where the ✨magic✨ happens..."""
@ -1008,7 +1008,7 @@ class Trainer:
if self.args.rank == 0:
self.dashboard_logger.finish()
except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt()
self.callbacks.on_keyboard_interrupt(self)
# if the output folder is empty remove the run.
remove_experiment_folder(self.output_path)
# clear the DDP processes

View File

@ -1,75 +1,101 @@
class TrainerCallback:
def __init__(self, trainer):
def __init__(self):
super().__init__()
self.trainer = trainer
def on_init_start(self) -> None:
if hasattr(self.trainer.model, "on_init_start"):
self.trainer.model.on_init_start(self.trainer)
def on_init_start(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_start"):
trainer.model.module.on_init_start(trainer)
else:
if hasattr(trainer.model, "on_init_start"):
trainer.model.on_init_start(trainer)
if hasattr(self.trainer.criterion, "on_init_start"):
self.trainer.criterion.on_init_start(self.trainer)
if hasattr(trainer.criterion, "on_init_start"):
trainer.criterion.on_init_start(trainer)
if hasattr(self.trainer.optimizer, "on_init_start"):
self.trainer.optimizer.on_init_start(self.trainer)
if hasattr(trainer.optimizer, "on_init_start"):
trainer.optimizer.on_init_start(trainer)
def on_init_end(self) -> None:
if hasattr(self.trainer.model, "on_init_end"):
self.trainer.model.on_init_end(self.trainer)
def on_init_end(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_end"):
trainer.model.module.on_init_end(trainer)
else:
if hasattr(trainer.model, "on_init_end"):
trainer.model.on_init_end(trainer)
if hasattr(self.trainer.criterion, "on_init_end"):
self.trainer.criterion.on_init_end(self.trainer)
if hasattr(trainer.criterion, "on_init_end"):
trainer.criterion.on_init_end(trainer)
if hasattr(self.trainer.optimizer, "on_init_end"):
self.trainer.optimizer.on_init_end(self.trainer)
if hasattr(trainer.optimizer, "on_init_end"):
trainer.optimizer.on_init_end(trainer)
def on_epoch_start(self) -> None:
if hasattr(self.trainer.model, "on_epoch_start"):
self.trainer.model.on_epoch_start(self.trainer)
def on_epoch_start(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_start"):
trainer.model.module.on_epoch_start(trainer)
else:
if hasattr(trainer.model, "on_epoch_start"):
trainer.model.on_epoch_start(trainer)
if hasattr(self.trainer.criterion, "on_epoch_start"):
self.trainer.criterion.on_epoch_start(self.trainer)
if hasattr(trainer.criterion, "on_epoch_start"):
trainer.criterion.on_epoch_start(trainer)
if hasattr(self.trainer.optimizer, "on_epoch_start"):
self.trainer.optimizer.on_epoch_start(self.trainer)
if hasattr(trainer.optimizer, "on_epoch_start"):
trainer.optimizer.on_epoch_start(trainer)
def on_epoch_end(self) -> None:
if hasattr(self.trainer.model, "on_epoch_end"):
self.trainer.model.on_epoch_end(self.trainer)
def on_epoch_end(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_end"):
trainer.model.module.on_epoch_end(trainer)
else:
if hasattr(trainer.model, "on_epoch_end"):
trainer.model.on_epoch_end(trainer)
if hasattr(self.trainer.criterion, "on_epoch_end"):
self.trainer.criterion.on_epoch_end(self.trainer)
if hasattr(trainer.criterion, "on_epoch_end"):
trainer.criterion.on_epoch_end(trainer)
if hasattr(self.trainer.optimizer, "on_epoch_end"):
self.trainer.optimizer.on_epoch_end(self.trainer)
if hasattr(trainer.optimizer, "on_epoch_end"):
trainer.optimizer.on_epoch_end(trainer)
def on_train_step_start(self) -> None:
if hasattr(self.trainer.model, "on_train_step_start"):
self.trainer.model.on_train_step_start(self.trainer)
def on_train_step_start(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_start"):
trainer.model.module.on_train_step_start(trainer)
else:
if hasattr(trainer.model, "on_train_step_start"):
trainer.model.on_train_step_start(trainer)
if hasattr(self.trainer.criterion, "on_train_step_start"):
self.trainer.criterion.on_train_step_start(self.trainer)
if hasattr(trainer.criterion, "on_train_step_start"):
trainer.criterion.on_train_step_start(trainer)
if hasattr(self.trainer.optimizer, "on_train_step_start"):
self.trainer.optimizer.on_train_step_start(self.trainer)
if hasattr(trainer.optimizer, "on_train_step_start"):
trainer.optimizer.on_train_step_start(trainer)
def on_train_step_end(self) -> None:
def on_train_step_end(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_end"):
trainer.model.module.on_train_step_end(trainer)
else:
if hasattr(trainer.model, "on_train_step_end"):
trainer.model.on_train_step_end(trainer)
if hasattr(self.trainer.model, "on_train_step_end"):
self.trainer.model.on_train_step_end(self.trainer)
if hasattr(trainer.criterion, "on_train_step_end"):
trainer.criterion.on_train_step_end(trainer)
if hasattr(self.trainer.criterion, "on_train_step_end"):
self.trainer.criterion.on_train_step_end(self.trainer)
if hasattr(trainer.optimizer, "on_train_step_end"):
trainer.optimizer.on_train_step_end(trainer)
if hasattr(self.trainer.optimizer, "on_train_step_end"):
self.trainer.optimizer.on_train_step_end(self.trainer)
def on_keyboard_interrupt(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_keyboard_interrupt"):
trainer.model.module.on_keyboard_interrupt(trainer)
else:
if hasattr(trainer.model, "on_keyboard_interrupt"):
trainer.model.on_keyboard_interrupt(trainer)
def on_keyboard_interrupt(self) -> None:
if hasattr(self.trainer.model, "on_keyboard_interrupt"):
self.trainer.model.on_keyboard_interrupt(self.trainer)
if hasattr(trainer.criterion, "on_keyboard_interrupt"):
trainer.criterion.on_keyboard_interrupt(trainer)
if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
self.trainer.criterion.on_keyboard_interrupt(self.trainer)
if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
trainer.optimizer.on_keyboard_interrupt(trainer)