optional loss masking for stoptoken predictor

This commit is contained in:
erogol 2020-10-28 18:40:54 +01:00
parent e49cc3bbcd
commit fdaed45f58
1 changed files with 11 additions and 5 deletions

View File

@ -163,14 +163,20 @@ class BCELossMasked(nn.Module):
""" """
# mask: (batch, max_len, 1) # mask: (batch, max_len, 1)
target.requires_grad = False target.requires_grad = False
mask = sequence_mask(sequence_length=length, if length is not None:
max_len=target.size(1)).float() mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).float()
x = x * mask
target = target * mask
num_items = mask.sum()
else:
num_items = torch.numel(x)
loss = functional.binary_cross_entropy_with_logits( loss = functional.binary_cross_entropy_with_logits(
x * mask, x,
target * mask, target,
pos_weight=self.pos_weight, pos_weight=self.pos_weight,
reduction='sum') reduction='sum')
loss = loss / mask.sum() loss = loss / num_items
return loss return loss