diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 0ee99930..b1cf886b 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E): scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) return [scheduler_D, scheduler_G] - def on_train_step_start(self, trainer): - """Schedule binary loss weight.""" - self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 - def on_epoch_end(self, trainer): # pylint: disable=unused-argument # stop updating mean and var # TODO: do the same for F0 @@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E): Args: trainer (Trainer): Trainer object. """ + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 self.train_disc = ( # pylint: disable=attribute-defined-outside-init trainer.total_steps_done >= self.config.steps_to_start_discriminator )