diff --git a/train.py b/train.py
index 693cef8a..524bacbf 100644
--- a/train.py
+++ b/train.py
@@ -134,7 +134,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
         optimizer, current_lr = weight_decay(optimizer, c.wd)
-        grad_norm, _ = check_update(model, 1.0)
+        grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

         # backpass and check the grad norm for stop loss