diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 9685f463..c81fde49 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None): seq_range = torch.arange(0, max_len).long() seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: - seq_range_expand = seq_range_expand.cuda() + seq_range_expand = seq_range_expand.to(sequence_length.device) seq_length_expand = ( sequence_length.unsqueeze(1).expand_as(seq_range_expand)) # B x T_max