Fix grad_norm handling

Eren Gölge 2021-10-21 16:16:50 +00:00
parent a409e0f8f8
commit 70e4d0e524
1 changed file with 1 addition and 1 deletion


@@ -647,7 +647,7 @@ class Trainer:
     optimizer.step()
     # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
-    if isinstance(grad_norm ,torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
+    if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
         grad_norm = 0
     step_time = time.time() - step_start_time
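
For context, a minimal sketch of the pattern this hunk touches, assuming grad_norm comes from torch.nn.utils.clip_grad_norm_, which returns the total gradient norm as a tensor that can be NaN or Inf. The model, optimizer, and train_step helper below are placeholders for illustration, not the actual Trainer code:

    import time
    import torch
    import torch.nn as nn

    # Placeholder model and optimizer for the sketch.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    def train_step(batch, target):
        step_start_time = time.time()
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(batch), target)
        loss.backward()

        # clip_grad_norm_ returns the total norm as a tensor; it can be NaN/Inf.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        # Same sanitization as the commit: report 0 instead of a NaN/Inf norm.
        if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
            grad_norm = 0

        step_time = time.time() - step_start_time
        return loss.item(), float(grad_norm), step_time

    # Example usage with random data.
    loss, gnorm, elapsed = train_step(torch.randn(4, 10), torch.randn(4, 1))

The change itself only fixes the comma spacing in the isinstance check; the sanitization logic (zeroing a non-finite grad_norm before it is logged) is unchanged.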