check against NaN loss in tacotron_loss

This commit is contained in:
erogol 2020-11-02 12:44:41 +01:00
parent ef04d7fae7
commit b8ac9aba9d
2 changed files with 9 additions and 4 deletions

View File

@ -1,14 +1,14 @@
#!/bin/bash
yes | apt-get install sox
yes | apt-get install ffmpeg
yes | apt-get install espeak
yes | apt-get install espeak
yes | apt-get install tmux
yes | apt-get install zsh
sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
sudo sh install.sh
pip install pytorch==1.3.0+cu100
python3 setup.py develop
# pip install pytorch==1.7.0+cu100
# python3 setup.py develop
# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/

View File

@ -228,7 +228,7 @@ class GuidedAttentionLoss(torch.nn.Module):
@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 /
(2 * (sigma**2)))
@ -373,6 +373,11 @@ class TacotronLoss(torch.nn.Module):
return_dict['postnet_ssim_loss'] = postnet_ssim_loss
return_dict['loss'] = loss
# check if any loss is NaN
for key, loss in return_dict.items():
if torch.isnan(loss):
raise RuntimeError(f" [!] NaN loss with {key}.")
return return_dict