diff --git a/TTS/trainer.py b/TTS/trainer.py index 9fcd77a7..2a2cfc46 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -284,8 +284,8 @@ class Trainer: self.optimizer = self.get_optimizer(self.model, self.config) # CALLBACK - self.callbacks = TrainerCallback(self) - self.callbacks.on_init_start() + self.callbacks = TrainerCallback() + self.callbacks.on_init_start(self) # init AMP if self.use_amp_scaler: @@ -324,7 +324,7 @@ class Trainer: num_params = count_parameters(self.model) print("\n > Model has {} parameters".format(num_params)) - self.callbacks.on_init_end() + self.callbacks.on_init_end(self) @staticmethod def parse_argv(args: Union[Coqpit, List]): @@ -677,7 +677,7 @@ class Trainer: Returns: Tuple[Dict, Dict]: Model outputs and losses. """ - self.callbacks.on_train_step_start() + self.callbacks.on_train_step_start(self) # format data batch = self.format_batch(batch) loader_time = time.time() - loader_start_time @@ -792,7 +792,7 @@ class Trainer: self.dashboard_logger.flush() self.total_steps_done += 1 - self.callbacks.on_train_step_end() + self.callbacks.on_train_step_end(self) return outputs, loss_dict def train_epoch(self) -> None: @@ -983,7 +983,7 @@ class Trainer: if self.num_gpus > 1: # let all processes sync up before starting with a new epoch of training dist.barrier() - self.callbacks.on_epoch_start() + self.callbacks.on_epoch_start(self) self.keep_avg_train = KeepAverage() self.keep_avg_eval = KeepAverage() if self.config.run_eval else None self.epochs_done = epoch @@ -999,7 +999,7 @@ class Trainer: ) if self.args.rank in [None, 0]: self.save_best_model() - self.callbacks.on_epoch_end() + self.callbacks.on_epoch_end(self) def fit(self) -> None: """Where the ✨️magic✨️ happens...""" @@ -1008,7 +1008,7 @@ class Trainer: if self.args.rank == 0: self.dashboard_logger.finish() except KeyboardInterrupt: - self.callbacks.on_keyboard_interrupt() + self.callbacks.on_keyboard_interrupt(self) # if the output folder is empty remove the run. remove_experiment_folder(self.output_path) # clear the DDP processes diff --git a/TTS/utils/callbacks.py b/TTS/utils/callbacks.py index 18b6c34c..3746eb15 100644 --- a/TTS/utils/callbacks.py +++ b/TTS/utils/callbacks.py @@ -1,75 +1,101 @@ class TrainerCallback: - def __init__(self, trainer): + def __init__(self): super().__init__() - self.trainer = trainer - def on_init_start(self) -> None: - if hasattr(self.trainer.model, "on_init_start"): - self.trainer.model.on_init_start(self.trainer) + def on_init_start(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_start"): + trainer.model.module.on_init_start(trainer) + else: + if hasattr(trainer.model, "on_init_start"): + trainer.model.on_init_start(trainer) - if hasattr(self.trainer.criterion, "on_init_start"): - self.trainer.criterion.on_init_start(self.trainer) + if hasattr(trainer.criterion, "on_init_start"): + trainer.criterion.on_init_start(trainer) - if hasattr(self.trainer.optimizer, "on_init_start"): - self.trainer.optimizer.on_init_start(self.trainer) + if hasattr(trainer.optimizer, "on_init_start"): + trainer.optimizer.on_init_start(trainer) - def on_init_end(self) -> None: - if hasattr(self.trainer.model, "on_init_end"): - self.trainer.model.on_init_end(self.trainer) + def on_init_end(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_end"): + trainer.model.module.on_init_end(trainer) + else: + if hasattr(trainer.model, "on_init_end"): + trainer.model.on_init_end(trainer) - if hasattr(self.trainer.criterion, "on_init_end"): - self.trainer.criterion.on_init_end(self.trainer) + if hasattr(trainer.criterion, "on_init_end"): + trainer.criterion.on_init_end(trainer) - if hasattr(self.trainer.optimizer, "on_init_end"): - self.trainer.optimizer.on_init_end(self.trainer) + if hasattr(trainer.optimizer, "on_init_end"): + trainer.optimizer.on_init_end(trainer) - def on_epoch_start(self) -> None: - if hasattr(self.trainer.model, "on_epoch_start"): - self.trainer.model.on_epoch_start(self.trainer) + def on_epoch_start(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_start"): + trainer.model.module.on_epoch_start(trainer) + else: + if hasattr(trainer.model, "on_epoch_start"): + trainer.model.on_epoch_start(trainer) - if hasattr(self.trainer.criterion, "on_epoch_start"): - self.trainer.criterion.on_epoch_start(self.trainer) + if hasattr(trainer.criterion, "on_epoch_start"): + trainer.criterion.on_epoch_start(trainer) - if hasattr(self.trainer.optimizer, "on_epoch_start"): - self.trainer.optimizer.on_epoch_start(self.trainer) + if hasattr(trainer.optimizer, "on_epoch_start"): + trainer.optimizer.on_epoch_start(trainer) - def on_epoch_end(self) -> None: - if hasattr(self.trainer.model, "on_epoch_end"): - self.trainer.model.on_epoch_end(self.trainer) + def on_epoch_end(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_end"): + trainer.model.module.on_epoch_end(trainer) + else: + if hasattr(trainer.model, "on_epoch_end"): + trainer.model.on_epoch_end(trainer) - if hasattr(self.trainer.criterion, "on_epoch_end"): - self.trainer.criterion.on_epoch_end(self.trainer) + if hasattr(trainer.criterion, "on_epoch_end"): + trainer.criterion.on_epoch_end(trainer) - if hasattr(self.trainer.optimizer, "on_epoch_end"): - self.trainer.optimizer.on_epoch_end(self.trainer) + if hasattr(trainer.optimizer, "on_epoch_end"): + trainer.optimizer.on_epoch_end(trainer) - def on_train_step_start(self) -> None: - if hasattr(self.trainer.model, "on_train_step_start"): - self.trainer.model.on_train_step_start(self.trainer) + def on_train_step_start(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_start"): + trainer.model.module.on_train_step_start(trainer) + else: + if hasattr(trainer.model, "on_train_step_start"): + trainer.model.on_train_step_start(trainer) - if hasattr(self.trainer.criterion, "on_train_step_start"): - self.trainer.criterion.on_train_step_start(self.trainer) + if hasattr(trainer.criterion, "on_train_step_start"): + trainer.criterion.on_train_step_start(trainer) - if hasattr(self.trainer.optimizer, "on_train_step_start"): - self.trainer.optimizer.on_train_step_start(self.trainer) + if hasattr(trainer.optimizer, "on_train_step_start"): + trainer.optimizer.on_train_step_start(trainer) - def on_train_step_end(self) -> None: + def on_train_step_end(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_end"): + trainer.model.module.on_train_step_end(trainer) + else: + if hasattr(trainer.model, "on_train_step_end"): + trainer.model.on_train_step_end(trainer) - if hasattr(self.trainer.model, "on_train_step_end"): - self.trainer.model.on_train_step_end(self.trainer) + if hasattr(trainer.criterion, "on_train_step_end"): + trainer.criterion.on_train_step_end(trainer) - if hasattr(self.trainer.criterion, "on_train_step_end"): - self.trainer.criterion.on_train_step_end(self.trainer) + if hasattr(trainer.optimizer, "on_train_step_end"): + trainer.optimizer.on_train_step_end(trainer) - if hasattr(self.trainer.optimizer, "on_train_step_end"): - self.trainer.optimizer.on_train_step_end(self.trainer) + def on_keyboard_interrupt(self, trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_keyboard_interrupt"): + trainer.model.module.on_keyboard_interrupt(trainer) + else: + if hasattr(trainer.model, "on_keyboard_interrupt"): + trainer.model.on_keyboard_interrupt(trainer) - def on_keyboard_interrupt(self) -> None: - if hasattr(self.trainer.model, "on_keyboard_interrupt"): - self.trainer.model.on_keyboard_interrupt(self.trainer) + if hasattr(trainer.criterion, "on_keyboard_interrupt"): + trainer.criterion.on_keyboard_interrupt(trainer) - if hasattr(self.trainer.criterion, "on_keyboard_interrupt"): - self.trainer.criterion.on_keyboard_interrupt(self.trainer) - - if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"): - self.trainer.optimizer.on_keyboard_interrupt(self.trainer) + if hasattr(trainer.optimizer, "on_keyboard_interrupt"): + trainer.optimizer.on_keyboard_interrupt(trainer)