diff --git a/train.py b/train.py index 1a175043..155a9571 100644 --- a/train.py +++ b/train.py @@ -389,6 +389,8 @@ def main(args): criterion.cuda() criterion_st.cuda() optimizer.load_state_dict(checkpoint['optimizer']) + for group in optimizer.param_groups: + group['lr'] = c.lr print( " > Model restored from step %d" % checkpoint['step'], flush=True) start_epoch = checkpoint['epoch']