diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 9ccb5d62..1366c4a6 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None): return mask -def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): """Segment each sample in a batch based on the provided segment indices Args: @@ -66,16 +66,25 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): segment_size (int): Expected output segment size. pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. """ - ret = torch.zeros_like(x[:, :, :segment_size]) + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + + segments = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): - idx_str = segment_indices[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + index_start = segment_indices[i] + index_end = index_start + segment_size + x_i = x[i] + if pad_short and index_end > x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] + return segments def rand_segments( - x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4 + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False ): """Create random segments based on the input lengths. @@ -90,16 +99,25 @@ def rand_segments( - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` """ - b, _, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - if (ids_str_max < 0).sum(): - raise ValueError("Segment size is larger than the input length.") - ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long() - ret = segment(x, ids_str, segment_size) - return ret, ids_str - + _x_lenghts = x_lengths.clone() + B, _, T = x.size() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + 1 + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + 1 + else: + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() + ret = segment(x, segment_indices, segment_size) + return ret, segment_indices def average_over_durations(values, durs): """Average values over durations.