masked loss

This commit is contained in:
Eren Golge 2018-03-22 21:13:33 -07:00
parent 32d9c734b2
commit 2617518d91
2 changed files with 3 additions and 3 deletions

View File

@ -47,5 +47,5 @@ def L1LossMasked(input, target, length):
# mask: (batch, max_len)
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / length.float().sum()
return loss / input.shape[0]
loss = losses.sum() / (length.float().sum() * target.shape[2])
return loss

View File

@ -240,7 +240,7 @@ def evaluate(model, criterion, data_loader, current_step):
mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)
# loss computation
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[: ,: ,:n_priority_freq],