diff --git a/vocoder/train.py b/vocoder/train.py index a8a8f011..091a6932 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -450,10 +450,14 @@ def main(args): # pylint: disable=redefined-outer-name weight_decay=0) # schedulers - scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) - scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) - scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) - scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) + scheduler_gen = None + scheduler_disc = None + if 'lr_scheduler_gen' in c: + scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) + scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) + if 'lr_scheduler_disc' in c: + scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) + scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) # setup criterion criterion_gen = GeneratorLoss(c)