diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index bf8a6df0..9c0764fb 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -497,12 +497,12 @@ def main(args): # pylint: disable=redefined-outer-name model_disc.load_state_dict(checkpoint['model_disc']) print(" > Restoring Discriminator Optimizer...") optimizer_disc.load_state_dict(checkpoint['optimizer_disc']) - if 'scheduler' in checkpoint: + if 'scheduler' in checkpoint and scheduler_gen is not None: print(" > Restoring Generator LR Scheduler...") scheduler_gen.load_state_dict(checkpoint['scheduler']) # NOTE: Not sure if necessary scheduler_gen.optimizer = optimizer_gen - if 'scheduler_disc' in checkpoint: + if 'scheduler_disc' in checkpoint and scheduler_disc is not None: print(" > Restoring Discriminator LR Scheduler...") scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) scheduler_disc.optimizer = optimizer_disc