diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index c2e7f561..b62004c8 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ index_start = segment_indices[i] index_end = index_start + segment_size x_i = x[i] - if pad_short and index_end > x.size(2): + if pad_short and index_end >= x.size(2): # pad the sample if it is shorter than the segment size x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) segments[i] = x_i[:, index_start:index_end] @@ -107,16 +107,16 @@ def rand_segments( T = segment_size if _x_lenghts is None: _x_lenghts = T - len_diff = _x_lenghts - segment_size + 1 + len_diff = _x_lenghts - segment_size if let_short_samples: _x_lenghts[len_diff < 0] = segment_size - len_diff = _x_lenghts - segment_size + 1 + len_diff = _x_lenghts - segment_size else: assert all( len_diff > 0 ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" - segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() - ret = segment(x, segment_indices, segment_size) + segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long() + ret = segment(x, segment_indices, segment_size, pad_short=pad_short) return ret, segment_indices