check nan loss in glow-tts loss

This commit is contained in:
erogol 2020-11-02 13:12:19 +01:00
parent b8ac9aba9d
commit a108d0ee81
1 changed files with 5 additions and 0 deletions

View File

@ -402,4 +402,9 @@ class GlowTTSLoss(torch.nn.Module):
return_dict['loss'] = log_mle + loss_dur
return_dict['log_mle'] = log_mle
return_dict['loss_dur'] = loss_dur
# check if any loss is NaN
for key, loss in return_dict.items():
if torch.isnan(loss):
raise RuntimeError(f" [!] NaN loss with {key}.")
return return_dict