From c27fd4238afb7a7c7e059f9b4cedcf37146b2aa2 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 22 May 2020 13:09:07 +0200 Subject: [PATCH] update torch version and solve compat issue with isinf --- requirements.txt | 2 +- setup.py | 2 +- utils/training.py | 12 +++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5f31db70..862cb229 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy>=1.16.0 -torch>=0.4.1 +torch>=1.5 librosa>=0.5.1 Unidecode>=0.4.20 tensorboard diff --git a/setup.py b/setup.py index 5e89723b..84a31488 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ setup( }, install_requires=[ "scipy>=0.19.0", - "torch>=0.4.1", + "torch>=1.5", "numpy>=1.16.0", "librosa==0.6.2", "unidecode==0.4.20", diff --git a/utils/training.py b/utils/training.py index ebf8fd13..5ce7948b 100644 --- a/utils/training.py +++ b/utils/training.py @@ -9,9 +9,15 @@ def check_update(model, grad_clip, ignore_stopnet=False): grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip) else: grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) - if torch.isinf(grad_norm): - print(" | > Gradient is INF !!") - skip_flag = True + # compatibility with different torch versions + if isinstance(grad_norm, float): + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + else: + if torch.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True return grad_norm, skip_flag