Fix rand_segment edge case (input_len == seg_len - 1)

This commit is contained in:
Eren G??lge 2022-08-01 11:37:45 +02:00
parent 5094499eba
commit 7d8b1665c8
1 changed files with 5 additions and 5 deletions

View File

@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
index_start = segment_indices[i]
index_end = index_start + segment_size
x_i = x[i]
if pad_short and index_end > x.size(2):
if pad_short and index_end >= x.size(2):
# pad the sample if it is shorter than the segment size
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
segments[i] = x_i[:, index_start:index_end]
@ -107,16 +107,16 @@ def rand_segments(
T = segment_size
if _x_lenghts is None:
_x_lenghts = T
len_diff = _x_lenghts - segment_size + 1
len_diff = _x_lenghts - segment_size
if let_short_samples:
_x_lenghts[len_diff < 0] = segment_size
len_diff = _x_lenghts - segment_size + 1
len_diff = _x_lenghts - segment_size
else:
assert all(
len_diff > 0
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
ret = segment(x, segment_indices, segment_size)
segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
return ret, segment_indices