def sequence_mask(sequence_length, max_len=None):
    """Build a boolean padding mask from per-item sequence lengths.

    Args:
        sequence_length: 1-D integer tensor of shape (B,) giving the valid
            length of each sequence in the batch.
        max_len: width T_max of the mask. Defaults to the longest length
            in the batch.

    Returns:
        Bool tensor of shape (B, T_max) where entry (b, t) is True iff
        t < sequence_length[b].
    """
    if max_len is None:
        # Tensor.max() instead of the deprecated .data.max(); lengths are
        # integer tensors and carry no autograd history anyway.
        max_len = sequence_length.max()
    # One arange on the right device/dtype; no batch-sized expand needed.
    seq_range = torch.arange(max_len,
                             dtype=sequence_length.dtype,
                             device=sequence_length.device)
    # Broadcast (1, T_max) against (B, 1) -> B x T_max.
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)