mirror of https://github.com/coqui-ai/TTS.git
Merge duplicate on_train_step_start functions in delightful_tts
This commit is contained in:
parent
861c68b0b8
commit
33a7c722f6
|
@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler_D, scheduler_G]
|
||||
|
||||
def on_train_step_start(self, trainer):
|
||||
"""Schedule binary loss weight."""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
|
||||
def on_epoch_end(self, trainer): # pylint: disable=unused-argument
|
||||
# stop updating mean and var
|
||||
# TODO: do the same for F0
|
||||
|
@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
Args:
|
||||
trainer (Trainer): Trainer object.
|
||||
"""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
self.train_disc = ( # pylint: disable=attribute-defined-outside-init
|
||||
trainer.total_steps_done >= self.config.steps_to_start_discriminator
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue