Fix callbacks against multi-gpu training

Eren Gölge 2021-12-01 10:32:14 +00:00
parent 9a145c9b88
commit 512ada7548
2 changed files with 86 additions and 60 deletions
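Background for the diff below: in multi-GPU runs the Trainer's model is wrapped in torch's DistributedDataParallel, which keeps the original network under its `.module` attribute and does not expose the network's custom methods on the wrapper itself. The old `hasattr(self.trainer.model, "on_...")` checks therefore always failed under DDP, and the callbacks silently never ran. The commit fixes this by dispatching through `model.module` when it exists, and by passing the trainer into each hook call instead of storing it on the callback at construction time. Below is a minimal sketch of the unwrap pattern; `MyModel` and `unwrap_model` are illustrative and not part of this codebase, and nn.DataParallel stands in for DistributedDataParallel (which would need an initialized process group):

from torch import nn

class MyModel(nn.Module):
    """Toy network defining one of the trainer hooks used in this commit."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def on_epoch_start(self, trainer):
        print("epoch starting")

def unwrap_model(model: nn.Module) -> nn.Module:
    # DataParallel and DistributedDataParallel both keep the wrapped
    # network under `.module`; a plain model has no such attribute.
    return model.module if hasattr(model, "module") else model

net = MyModel()
wrapped = nn.DataParallel(net)  # DDP wraps the same way

assert not hasattr(wrapped, "on_epoch_start")            # hook hidden by the wrapper
assert hasattr(unwrap_model(wrapped), "on_epoch_start")  # reachable via .module
assert unwrap_model(net) is net                          # no-op for plain models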

Changed file 1 of 2:

@@ -284,8 +284,8 @@ class Trainer:
         self.optimizer = self.get_optimizer(self.model, self.config)

         # CALLBACK
-        self.callbacks = TrainerCallback(self)
-        self.callbacks.on_init_start()
+        self.callbacks = TrainerCallback()
+        self.callbacks.on_init_start(self)

         # init AMP
         if self.use_amp_scaler:
@@ -324,7 +324,7 @@ class Trainer:
         num_params = count_parameters(self.model)
         print("\n > Model has {} parameters".format(num_params))

-        self.callbacks.on_init_end()
+        self.callbacks.on_init_end(self)

     @staticmethod
     def parse_argv(args: Union[Coqpit, List]):
@@ -677,7 +677,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
-        self.callbacks.on_train_step_start()
+        self.callbacks.on_train_step_start(self)
         # format data
         batch = self.format_batch(batch)
         loader_time = time.time() - loader_start_time
@@ -792,7 +792,7 @@ class Trainer:
             self.dashboard_logger.flush()

         self.total_steps_done += 1
-        self.callbacks.on_train_step_end()
+        self.callbacks.on_train_step_end(self)
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -983,7 +983,7 @@ class Trainer:
             if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
-            self.callbacks.on_epoch_start()
+            self.callbacks.on_epoch_start(self)
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
@@ -999,7 +999,7 @@ class Trainer:
             )
             if self.args.rank in [None, 0]:
                 self.save_best_model()
-            self.callbacks.on_epoch_end()
+            self.callbacks.on_epoch_end(self)

     def fit(self) -> None:
         """Where the ✨magic✨ happens..."""
@@ -1008,7 +1008,7 @@ class Trainer:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
-            self.callbacks.on_keyboard_interrupt()
+            self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # clear the DDP processes

Changed file 2 of 2:

@@ -1,75 +1,101 @@
 class TrainerCallback:
-    def __init__(self, trainer):
+    def __init__(self):
         super().__init__()
-        self.trainer = trainer

-    def on_init_start(self) -> None:
-        if hasattr(self.trainer.model, "on_init_start"):
-            self.trainer.model.on_init_start(self.trainer)
+    def on_init_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_start"):
+                trainer.model.module.on_init_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_start"):
+                trainer.model.on_init_start(trainer)

-        if hasattr(self.trainer.criterion, "on_init_start"):
-            self.trainer.criterion.on_init_start(self.trainer)
+        if hasattr(trainer.criterion, "on_init_start"):
+            trainer.criterion.on_init_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_init_start"):
-            self.trainer.optimizer.on_init_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_start"):
+            trainer.optimizer.on_init_start(trainer)

-    def on_init_end(self) -> None:
-        if hasattr(self.trainer.model, "on_init_end"):
-            self.trainer.model.on_init_end(self.trainer)
+    def on_init_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_end"):
+                trainer.model.module.on_init_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_end"):
+                trainer.model.on_init_end(trainer)

-        if hasattr(self.trainer.criterion, "on_init_end"):
-            self.trainer.criterion.on_init_end(self.trainer)
+        if hasattr(trainer.criterion, "on_init_end"):
+            trainer.criterion.on_init_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_init_end"):
-            self.trainer.optimizer.on_init_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_end"):
+            trainer.optimizer.on_init_end(trainer)

-    def on_epoch_start(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_start"):
-            self.trainer.model.on_epoch_start(self.trainer)
+    def on_epoch_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_start"):
+                trainer.model.module.on_epoch_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_start"):
+                trainer.model.on_epoch_start(trainer)

-        if hasattr(self.trainer.criterion, "on_epoch_start"):
-            self.trainer.criterion.on_epoch_start(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_start"):
+            trainer.criterion.on_epoch_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_epoch_start"):
-            self.trainer.optimizer.on_epoch_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_start"):
+            trainer.optimizer.on_epoch_start(trainer)

-    def on_epoch_end(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_end"):
-            self.trainer.model.on_epoch_end(self.trainer)
+    def on_epoch_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_end"):
+                trainer.model.module.on_epoch_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_end"):
+                trainer.model.on_epoch_end(trainer)

-        if hasattr(self.trainer.criterion, "on_epoch_end"):
-            self.trainer.criterion.on_epoch_end(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_end"):
+            trainer.criterion.on_epoch_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_epoch_end"):
-            self.trainer.optimizer.on_epoch_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_end"):
+            trainer.optimizer.on_epoch_end(trainer)

-    def on_train_step_start(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_start"):
-            self.trainer.model.on_train_step_start(self.trainer)
+    def on_train_step_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_start"):
+                trainer.model.module.on_train_step_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_start"):
+                trainer.model.on_train_step_start(trainer)

-        if hasattr(self.trainer.criterion, "on_train_step_start"):
-            self.trainer.criterion.on_train_step_start(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_start"):
+            trainer.criterion.on_train_step_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_train_step_start"):
-            self.trainer.optimizer.on_train_step_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_start"):
+            trainer.optimizer.on_train_step_start(trainer)

-    def on_train_step_end(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_end"):
-            self.trainer.model.on_train_step_end(self.trainer)
+    def on_train_step_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_end"):
+                trainer.model.module.on_train_step_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_end"):
+                trainer.model.on_train_step_end(trainer)

-        if hasattr(self.trainer.criterion, "on_train_step_end"):
-            self.trainer.criterion.on_train_step_end(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_end"):
+            trainer.criterion.on_train_step_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_train_step_end"):
-            self.trainer.optimizer.on_train_step_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_end"):
+            trainer.optimizer.on_train_step_end(trainer)

-    def on_keyboard_interrupt(self) -> None:
-        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
-            self.trainer.model.on_keyboard_interrupt(self.trainer)
+    def on_keyboard_interrupt(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_keyboard_interrupt"):
+                trainer.model.module.on_keyboard_interrupt(trainer)
+        else:
+            if hasattr(trainer.model, "on_keyboard_interrupt"):
+                trainer.model.on_keyboard_interrupt(trainer)

-        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
-            self.trainer.criterion.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.criterion, "on_keyboard_interrupt"):
+            trainer.criterion.on_keyboard_interrupt(trainer)

-        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
-            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
+            trainer.optimizer.on_keyboard_interrupt(trainer)
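
For reference, a minimal usage sketch of the reworked API, reusing the illustrative `MyModel` from the sketch above together with the new `TrainerCallback` (the `+` side of this diff); `SimpleNamespace` is a stand-in for the real Trainer, which passes itself into every hook:

from types import SimpleNamespace
from torch import nn

trainer = SimpleNamespace(
    model=MyModel(),
    criterion=nn.MSELoss(),  # defines no hooks, so the hasattr checks skip it
    optimizer=None,          # likewise
)

callbacks = TrainerCallback()      # no trainer captured at construction any more
callbacks.on_epoch_start(trainer)  # reaches MyModel.on_epoch_start directly

trainer.model = nn.DataParallel(trainer.model)  # simulate the multi-GPU wrapper
callbacks.on_epoch_start(trainer)               # still fires, now via model.module

Passing the trainer per call keeps TrainerCallback stateless, so the same instance serves whichever trainer invokes it, wrapped or not.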