mirror of https://github.com/coqui-ai/TTS.git
faster sequence masking
This commit is contained in:
parent 673ba74a80
commit 1dea2c9034
@@ -34,15 +34,11 @@ def split_dataset(items):
 def sequence_mask(sequence_length, max_len=None):
     if max_len is None:
         max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.to(sequence_length.device)
-    seq_length_expand = (
-        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
+    seq_range = torch.arange(max_len,
+                             dtype=sequence_length.dtype,
+                             device=sequence_length.device)
     # B x T_max
-    return seq_range_expand < seq_length_expand
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
 
 
 def to_camel(text):
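For context, the refactor builds the index range directly on the input tensor's device and dtype (instead of creating it on CPU, expanding it, and moving it over in an is_cuda branch) and lets broadcasting handle the B x T_max expansion. A minimal standalone sketch of the new function with an illustrative call; the lengths values below are made up, not from the repository:

import torch

def sequence_mask(sequence_length, max_len=None):
    # New version from this commit: the index range is created directly on the
    # input's device/dtype and broadcasting performs the B x T_max expansion.
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len,
                             dtype=sequence_length.dtype,
                             device=sequence_length.device)
    # B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

# Illustrative usage (these lengths are hypothetical):
lengths = torch.tensor([2, 4, 3])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True, False]])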