diff --git a/layers/losses.py b/layers/losses.py
new file mode 100644
index 00000000..0fdc654e
--- /dev/null
+++ b/layers/losses.py
@@ -0,0 +1,51 @@
+import torch
+from torch.autograd import Variable
+from torch.nn import functional
+
+
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def _sequence_mask(sequence_length, max_len=None):
+    if max_len is None:
+        max_len = sequence_length.data.max()
+    batch_size = sequence_length.size(0)
+    seq_range = torch.arange(0, max_len).long()
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_range_expand = Variable(seq_range_expand)
+    if sequence_length.is_cuda:
+        seq_range_expand = seq_range_expand.cuda()
+    seq_length_expand = (sequence_length.unsqueeze(1)
+                         .expand_as(seq_range_expand))
+    return seq_range_expand < seq_length_expand
+
+
+def L1LossMasked(input, target, length):
+    """
+    Args:
+        input: A Variable containing a FloatTensor of size
+            (batch, max_len, dim) which contains the
+            predicted values for each step.
+        target: A Variable containing a FloatTensor of size
+            (batch, max_len, dim) which contains the target
+            values for each corresponding step.
+        length: A Variable containing a LongTensor of size (batch,)
+            which contains the length of each sequence in the batch.
+    Returns:
+        loss: An average loss value masked by the length.
+    """
+
+    # input_flat: (batch * max_len, dim)
+    input = input.view(-1, input.size(-1))
+    # target_flat: (batch * max_len, dim)
+    target_flat = target.view(-1, target.size(-1))
+    # losses_flat: (batch * max_len, dim)
+    losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                     reduce=False)
+    # losses: (batch, max_len, dim)
+    losses = losses_flat.view(*target.size())
+    # mask: (batch, max_len, 1)
+    mask = _sequence_mask(sequence_length=length,
+                          max_len=target.size(1)).unsqueeze(2)
+    losses = losses * mask.float()
+    # average over the valid (unpadded) steps in the batch
+    loss = losses.sum() / length.float().sum()
+    return loss
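
For reference, a minimal usage sketch (the shapes, values, and the `layers.losses` import path are illustrative, not part of this PR); it assumes the Variable-era PyTorch API this file targets, where `functional.l1_loss` still accepts the `size_average`/`reduce` keywords:

    import torch
    from torch.autograd import Variable

    from layers.losses import L1LossMasked

    # batch of 2 sequences padded to max_len=4, with feature dim=3
    prediction = Variable(torch.rand(2, 4, 3))
    target = Variable(torch.rand(2, 4, 3))
    # true (unpadded) length of each sequence in the batch
    length = Variable(torch.LongTensor([4, 2]))

    # padded steps beyond each sequence's length contribute
    # nothing to the summed loss
    loss = L1LossMasked(prediction, target, length)
    print(loss)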