diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 5d91d74d..435d2b10 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -157,7 +157,7 @@ def check_update(model, grad_clip, ignore_stopnet=False):
         grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
     else:
         grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    if np.isinf(grad_norm):
+    if np.isinf(grad_norm.cpu()):
         print(" | > Gradient is INF !!")
         skip_flag = True
     return grad_norm, skip_flag
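
For context, a minimal sketch (not part of the patch) of the failure this change addresses: on recent PyTorch versions `clip_grad_norm_` returns a 0-dim tensor on the same device as the parameters' gradients, and NumPy cannot read a CUDA tensor directly, so the norm has to be copied to host memory before `np.isinf` can inspect it. The snippet assumes a CUDA device is available; the value is a stand-in for the norm returned by `clip_grad_norm_`.

```python
import numpy as np
import torch

if torch.cuda.is_available():
    # Stand-in for the value check_update() gets back from clip_grad_norm_
    grad_norm = torch.tensor(float("inf"), device="cuda")

    # np.isinf(grad_norm) would raise a TypeError here, because a CUDA tensor
    # cannot be converted to a NumPy array without copying it to the host first.
    if np.isinf(grad_norm.cpu()):  # .cpu() copies the scalar to host memory
        print(" | > Gradient is INF !!")
```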