From 523fa5dfd2c183ad030bbcbf4ea37f0a3ae82653 Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 18 May 2020 11:35:19 +0200
Subject: [PATCH] pass sequence mask to the same device as the input

---
 utils/generic_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 9685f463..c81fde49 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None):
     seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
+        seq_range_expand = seq_range_expand.to(sequence_length.device)
     seq_length_expand = (
         sequence_length.unsqueeze(1).expand_as(seq_range_expand))
     # B x T_max
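
For context, a minimal sketch of how the patched sequence_mask is expected to behave. Only the lines inside the hunk come from this patch; the surrounding lines (inferring max_len, computing batch_size, and the final comparison/return) are assumptions reconstructed from the visible context. The point of the change is that .to(sequence_length.device) places the range tensor on the exact device that holds sequence_length, whereas .cuda() always targets the current default CUDA device and can cause a device-mismatch error in multi-GPU setups.

import torch

def sequence_mask(sequence_length, max_len=None):
    # Hypothetical surrounding lines (not part of the hunk): infer mask
    # width and batch size from the lengths tensor.
    if max_len is None:
        max_len = int(sequence_length.max())
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        # Patched line: move the range tensor to the same device as the
        # input lengths, not merely the default CUDA device.
        seq_range_expand = seq_range_expand.to(sequence_length.device)
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max boolean mask: True where the position index is below the
    # corresponding sequence length (assumed return, not in the hunk).
    return seq_range_expand < seq_length_expand

# Usage sketch: lengths placed on cuda:1 now yield a mask on cuda:1 as well.
# lengths = torch.tensor([3, 5, 2], device="cuda:1")
# mask = sequence_mask(lengths)   # shape (3, 5), on cuda:1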