faster sequence masking

This commit is contained in:
erogol 2020-08-06 12:33:49 +02:00
parent 673ba74a80
commit 1dea2c9034
1 changed file with 4 additions and 8 deletions

View File

@ -34,15 +34,11 @@ def split_dataset(items):
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean padding mask from per-sample sequence lengths.

    Args:
        sequence_length: 1-D integer tensor of shape (B,) holding the valid
            length of each sequence in the batch.
        max_len: Optional mask width. Defaults to the longest length in the
            batch.

    Returns:
        Bool tensor of shape (B, max_len); True at positions < length.
    """
    if max_len is None:
        # .max().item() instead of the deprecated .data.max(): yields a plain
        # Python int, which torch.arange accepts on every torch version.
        max_len = sequence_length.max().item()
    seq_range = torch.arange(max_len,
                             dtype=sequence_length.dtype,
                             device=sequence_length.device)
    # Broadcast (1, T_max) against (B, 1) -> B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
def to_camel(text):