Merge duplicate on_train_step_start functions in delightful_tts

This commit is contained in:
Aarni Koskela 2023-09-27 01:10:44 +03:00
parent 861c68b0b8
commit 33a7c722f6
1 changed files with 1 additions and 4 deletions

View File

@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler_D, scheduler_G]
def on_train_step_start(self, trainer):
"""Schedule binary loss weight."""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
def on_epoch_end(self, trainer): # pylint: disable=unused-argument
# stop updating mean and var
# TODO: do the same for F0
@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
Args:
trainer (Trainer): Trainer object.
"""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
self.train_disc = ( # pylint: disable=attribute-defined-outside-init
trainer.total_steps_done >= self.config.steps_to_start_discriminator
)