optional loss masking for stop token predictor

erogol 2020-10-28 18:40:54 +01:00
parent e49cc3bbcd
commit fdaed45f58
1 changed file with 11 additions and 5 deletions


@@ -163,14 +163,20 @@ class BCELossMasked(nn.Module):
         """
         # mask: (batch, max_len, 1)
         target.requires_grad = False
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).float()
+        if length is not None:
+            mask = sequence_mask(sequence_length=length,
+                                 max_len=target.size(1)).float()
+            x = x * mask
+            target = target * mask
+            num_items = mask.sum()
+        else:
+            num_items = torch.numel(x)
         loss = functional.binary_cross_entropy_with_logits(
-            x * mask,
-            target * mask,
+            x,
+            target,
             pos_weight=self.pos_weight,
             reduction='sum')
-        loss = loss / mask.sum()
+        loss = loss / num_items
         return loss
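
For reference, the following is a minimal, self-contained sketch of BCELossMasked as it reads after this patch. The sequence_mask helper below is a stand-in with assumed semantics (the repository imports its own version), and the length=None / pos_weight=None defaults are illustrative additions, not taken from the commit.

import torch
from torch import nn
from torch.nn import functional


def sequence_mask(sequence_length, max_len=None):
    # Stand-in for the repo's helper (assumed semantics):
    # mask[b, t] is True for t < sequence_length[b].
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)


class BCELossMasked(nn.Module):
    def __init__(self, pos_weight=None):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, x, target, length=None):
        target.requires_grad = False
        if length is not None:
            # Masked path: zero out padded steps and average
            # over the number of valid steps only.
            mask = sequence_mask(sequence_length=length,
                                 max_len=target.size(1)).float()
            x = x * mask
            target = target * mask
            num_items = mask.sum()
        else:
            # Unmasked path added by this commit: average over
            # every element of the batch instead.
            num_items = torch.numel(x)
        loss = functional.binary_cross_entropy_with_logits(
            x, target, pos_weight=self.pos_weight, reduction='sum')
        return loss / num_items


# Usage sketch (hypothetical tensors): two stop-token sequences padded to 5 steps.
logits = torch.randn(2, 5)
target = torch.zeros(2, 5)
target[0, 3] = 1.0          # stop fires at step 3 for the first item
target[1, 4] = 1.0          # and at step 4 for the second
lengths = torch.tensor([4, 5])
criterion = BCELossMasked(pos_weight=torch.tensor(10.0))
masked_loss = criterion(logits, target, lengths)   # averaged over mask.sum() = 9 valid steps
plain_loss = criterion(logits, target, None)       # averaged over numel = 10 elements

The practical effect of the new else branch is that the same criterion can now score targets for which no lengths are supplied, falling back to a plain per-element mean instead of requiring a padding mask.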