diff --git a/train.py b/train.py index d24a4b6b..ed2085b9 100644 --- a/train.py +++ b/train.py @@ -602,7 +602,7 @@ def main(args): # pylint: disable=redefined-outer-name if num_gpus > 1: model = apply_gradient_allreduce(model) - if c.lr_decay: + if c.noam_schedule: scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)