mirror of https://github.com/coqui-ai/TTS.git
Fix callbacks against multi-gpu training
This commit is contained in:
parent
9a145c9b88
commit
512ada7548
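
In multi-GPU runs the model is wrapped in DistributedDataParallel, which hides user-defined hooks such as `on_epoch_start` behind the wrapper's `module` attribute; the old callbacks only probed the wrapper, so every model hook was silently skipped. The fix probes for `module` first and dispatches to the wrapped model when present, and it also passes the trainer into each hook instead of storing it on the callback. A minimal sketch of the failure mode, assuming a standard `torch.nn.parallel.DistributedDataParallel` wrapper (the model class and hook body are illustrative, not from the commit):

```python
import torch.nn as nn

class MyModel(nn.Module):
    # Illustrative model exposing one of the trainer hooks dispatched below.
    def on_epoch_start(self, trainer):
        print("hook fired")

model = MyModel()
print(hasattr(model, "on_epoch_start"))  # True: hook visible on the bare model

# DistributedDataParallel registers the wrapped model as a submodule named
# "module" and does not forward arbitrary attribute lookups, so after
#   ddp = nn.parallel.DistributedDataParallel(model)
# (requires an initialized process group, omitted here):
#   hasattr(ddp, "on_epoch_start")        -> False: the hook is hidden
#   hasattr(ddp, "module")                -> True: what the fixed code tests
#   hasattr(ddp.module, "on_epoch_start") -> True: the dispatch target
```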
@@ -284,8 +284,8 @@ class Trainer:
         self.optimizer = self.get_optimizer(self.model, self.config)

         # CALLBACK
-        self.callbacks = TrainerCallback(self)
-        self.callbacks.on_init_start()
+        self.callbacks = TrainerCallback()
+        self.callbacks.on_init_start(self)

         # init AMP
         if self.use_amp_scaler:
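
This first hunk also inverts the wiring: `TrainerCallback` is constructed without arguments and the trainer is handed to each hook at call time, keeping the dispatcher stateless. A toy sketch of the new calling convention, with hypothetical stand-ins for both classes:

```python
class StatelessCallback:
    # New convention: no stored trainer; every hook receives it explicitly.
    def on_init_start(self, trainer):
        print("init:", trainer.name)

class FakeTrainer:
    """Hypothetical stand-in; the real Trainer does this in __init__."""
    def __init__(self, name):
        self.name = name
        self.callbacks = StatelessCallback()  # was: TrainerCallback(self)
        self.callbacks.on_init_start(self)    # was: on_init_start()

FakeTrainer("run-0")  # prints "init: run-0"
```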
@@ -324,7 +324,7 @@ class Trainer:
         num_params = count_parameters(self.model)
         print("\n > Model has {} parameters".format(num_params))

-        self.callbacks.on_init_end()
+        self.callbacks.on_init_end(self)

     @staticmethod
     def parse_argv(args: Union[Coqpit, List]):
@@ -677,7 +677,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
-        self.callbacks.on_train_step_start()
+        self.callbacks.on_train_step_start(self)
         # format data
         batch = self.format_batch(batch)
         loader_time = time.time() - loader_start_time
@@ -792,7 +792,7 @@ class Trainer:
             self.dashboard_logger.flush()

         self.total_steps_done += 1
-        self.callbacks.on_train_step_end()
+        self.callbacks.on_train_step_end(self)
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -983,7 +983,7 @@ class Trainer:
             if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
-            self.callbacks.on_epoch_start()
+            self.callbacks.on_epoch_start(self)
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
@@ -999,7 +999,7 @@ class Trainer:
                 )
                 if self.args.rank in [None, 0]:
                     self.save_best_model()
-            self.callbacks.on_epoch_end()
+            self.callbacks.on_epoch_end(self)

     def fit(self) -> None:
         """Where the ✨️magic✨️ happens..."""
@@ -1008,7 +1008,7 @@ class Trainer:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
-            self.callbacks.on_keyboard_interrupt()
+            self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # clear the DDP processes
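
Every hook in the rewritten dispatcher below repeats the same pattern: unwrap the DDP `module` attribute if present, duck-type on the hook name, and pass the trainer through; the criterion and optimizer are never DDP-wrapped, so they keep the flat check. As a reading aid, each model branch is equivalent to this hypothetical helper (not part of the commit):

```python
def _fire_model_hook(model, hook_name, trainer):
    # Hypothetical reading aid: DDP-wrapped models hide user hooks behind
    # .module, so unwrap first, then duck-type on the hook name.
    if hasattr(model, "module"):
        if hasattr(model.module, hook_name):
            getattr(model.module, hook_name)(trainer)
    else:
        if hasattr(model, hook_name):
            getattr(model, hook_name)(trainer)
```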
@@ -1,75 +1,101 @@
 class TrainerCallback:
-    def __init__(self, trainer):
+    def __init__(self):
         super().__init__()
-        self.trainer = trainer

-    def on_init_start(self) -> None:
-        if hasattr(self.trainer.model, "on_init_start"):
-            self.trainer.model.on_init_start(self.trainer)
+    def on_init_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_start"):
+                trainer.model.module.on_init_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_start"):
+                trainer.model.on_init_start(trainer)

-        if hasattr(self.trainer.criterion, "on_init_start"):
-            self.trainer.criterion.on_init_start(self.trainer)
+        if hasattr(trainer.criterion, "on_init_start"):
+            trainer.criterion.on_init_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_init_start"):
-            self.trainer.optimizer.on_init_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_start"):
+            trainer.optimizer.on_init_start(trainer)

-    def on_init_end(self) -> None:
-        if hasattr(self.trainer.model, "on_init_end"):
-            self.trainer.model.on_init_end(self.trainer)
+    def on_init_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_end"):
+                trainer.model.module.on_init_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_end"):
+                trainer.model.on_init_end(trainer)

-        if hasattr(self.trainer.criterion, "on_init_end"):
-            self.trainer.criterion.on_init_end(self.trainer)
+        if hasattr(trainer.criterion, "on_init_end"):
+            trainer.criterion.on_init_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_init_end"):
-            self.trainer.optimizer.on_init_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_init_end"):
+            trainer.optimizer.on_init_end(trainer)

-    def on_epoch_start(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_start"):
-            self.trainer.model.on_epoch_start(self.trainer)
+    def on_epoch_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_start"):
+                trainer.model.module.on_epoch_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_start"):
+                trainer.model.on_epoch_start(trainer)

-        if hasattr(self.trainer.criterion, "on_epoch_start"):
-            self.trainer.criterion.on_epoch_start(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_start"):
+            trainer.criterion.on_epoch_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_epoch_start"):
-            self.trainer.optimizer.on_epoch_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_start"):
+            trainer.optimizer.on_epoch_start(trainer)

-    def on_epoch_end(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_end"):
-            self.trainer.model.on_epoch_end(self.trainer)
+    def on_epoch_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_end"):
+                trainer.model.module.on_epoch_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_end"):
+                trainer.model.on_epoch_end(trainer)

-        if hasattr(self.trainer.criterion, "on_epoch_end"):
-            self.trainer.criterion.on_epoch_end(self.trainer)
+        if hasattr(trainer.criterion, "on_epoch_end"):
+            trainer.criterion.on_epoch_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_epoch_end"):
-            self.trainer.optimizer.on_epoch_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_epoch_end"):
+            trainer.optimizer.on_epoch_end(trainer)

-    def on_train_step_start(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_start"):
-            self.trainer.model.on_train_step_start(self.trainer)
+    def on_train_step_start(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_start"):
+                trainer.model.module.on_train_step_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_start"):
+                trainer.model.on_train_step_start(trainer)

-        if hasattr(self.trainer.criterion, "on_train_step_start"):
-            self.trainer.criterion.on_train_step_start(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_start"):
+            trainer.criterion.on_train_step_start(trainer)

-        if hasattr(self.trainer.optimizer, "on_train_step_start"):
-            self.trainer.optimizer.on_train_step_start(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_start"):
+            trainer.optimizer.on_train_step_start(trainer)

-    def on_train_step_end(self) -> None:
-
-        if hasattr(self.trainer.model, "on_train_step_end"):
-            self.trainer.model.on_train_step_end(self.trainer)
+    def on_train_step_end(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_end"):
+                trainer.model.module.on_train_step_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_end"):
+                trainer.model.on_train_step_end(trainer)

-        if hasattr(self.trainer.criterion, "on_train_step_end"):
-            self.trainer.criterion.on_train_step_end(self.trainer)
+        if hasattr(trainer.criterion, "on_train_step_end"):
+            trainer.criterion.on_train_step_end(trainer)

-        if hasattr(self.trainer.optimizer, "on_train_step_end"):
-            self.trainer.optimizer.on_train_step_end(self.trainer)
+        if hasattr(trainer.optimizer, "on_train_step_end"):
+            trainer.optimizer.on_train_step_end(trainer)

-    def on_keyboard_interrupt(self) -> None:
-        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
-            self.trainer.model.on_keyboard_interrupt(self.trainer)
+    def on_keyboard_interrupt(self, trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_keyboard_interrupt"):
+                trainer.model.module.on_keyboard_interrupt(trainer)
+        else:
+            if hasattr(trainer.model, "on_keyboard_interrupt"):
+                trainer.model.on_keyboard_interrupt(trainer)

-        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
-            self.trainer.criterion.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.criterion, "on_keyboard_interrupt"):
+            trainer.criterion.on_keyboard_interrupt(trainer)

-        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
-            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
+        if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
+            trainer.optimizer.on_keyboard_interrupt(trainer)
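
As a quick sanity check of the new contract, a sketch that drives the dispatcher with hypothetical stand-ins (`LoggingModel` and `DummyTrainer` are not part of the codebase); nothing here is DDP-wrapped, so the bare-model branch fires:

```python
class LoggingModel:
    def on_epoch_end(self, trainer):
        print(f"epoch {trainer.epochs_done} done")

class DummyTrainer:
    def __init__(self):
        self.model = LoggingModel()  # bare model: the else branch dispatches
        self.criterion = object()    # no hooks: the hasattr checks skip these
        self.optimizer = object()
        self.epochs_done = 0

callbacks = TrainerCallback()           # note: no trainer argument anymore
callbacks.on_epoch_end(DummyTrainer())  # prints "epoch 0 done"
```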