diff --git a/TTS/trainer.py b/TTS/trainer.py index 68b45fe2..0b4ad308 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -268,9 +268,13 @@ class Trainer: self.config, args.restore_path, self.model, self.optimizer, self.scaler ) - # setup scheduler + + # setup scheduler self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) + if self.args.continue_path: + self.scheduler.last_epoch = self.restore_step + # DISTRUBUTED if self.num_gpus > 1: self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) @@ -291,7 +295,6 @@ class Trainer: Returns: nn.Module: initialized model. """ - # TODO: better model setup try: model = setup_vocoder_model(config) except ModuleNotFoundError: