mirror of https://github.com/coqui-ai/TTS.git
Fix rand_segment edge case (input_len == seg_len - 1)
This commit is contained in:
parent
5094499eba
commit
7d8b1665c8
|
@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
|
|||
index_start = segment_indices[i]
|
||||
index_end = index_start + segment_size
|
||||
x_i = x[i]
|
||||
if pad_short and index_end > x.size(2):
|
||||
if pad_short and index_end >= x.size(2):
|
||||
# pad the sample if it is shorter than the segment size
|
||||
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
|
||||
segments[i] = x_i[:, index_start:index_end]
|
||||
|
@ -107,16 +107,16 @@ def rand_segments(
|
|||
T = segment_size
|
||||
if _x_lenghts is None:
|
||||
_x_lenghts = T
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
len_diff = _x_lenghts - segment_size
|
||||
if let_short_samples:
|
||||
_x_lenghts[len_diff < 0] = segment_size
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
len_diff = _x_lenghts - segment_size
|
||||
else:
|
||||
assert all(
|
||||
len_diff > 0
|
||||
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
|
||||
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
|
||||
ret = segment(x, segment_indices, segment_size)
|
||||
segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
|
||||
ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
|
||||
return ret, segment_indices
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue