diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index b0a010b0..32513377 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -57,40 +57,61 @@ def sequence_mask(sequence_length, max_len=None): return mask -def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4): +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): """Segment each sample in a batch based on the provided segment indices Args: x (torch.tensor): Input tensor. segment_indices (torch.tensor): Segment indices. segment_size (int): Expected output segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. """ + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + segments = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): index_start = segment_indices[i] index_end = index_start + segment_size - segments[i] = x[i, :, index_start:index_end] + x_i = x[i] + if pad_short and index_end > x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] return segments -def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4): +def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False): """Create random segments based on the input lengths. Args: x (torch.tensor): Input tensor. x_lengths (torch.tensor): Input lengths. segment_size (int): Expected output segment size. + let_short_samples (bool): Allow shorter samples than the segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. Shapes: - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` """ + _x_lenghts = x_lengths.clone() B, _, T = x.size() - if x_lengths is None: - x_lengths = T - max_idxs = x_lengths - segment_size + 1 - assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size." - segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + 1 + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + 1 + else: + assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() ret = segment(x, segment_indices, segment_size) return ret, segment_indices diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index 6a2f260d..708ecbf5 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -1,6 +1,6 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask +from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments def average_over_durations_test(): # pylint: disable=no-self-use @@ -39,6 +39,34 @@ def segment_test(): for idx, start_indx in enumerate(segment_ids): assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum() + try: + segments = segment(x, segment_ids, segment_size=10) + raise Exception("Should have failed") + except: + pass + + segments = segment(x, segment_ids, segment_size=10, pad_short=True) + for idx, start_indx in enumerate(segment_ids): + assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum() + + +def rand_segments_test(): + x = T.rand(2, 3, 4) + x_lens = T.randint(3, 4, (2,)) + segments, seg_idxs = rand_segments(x, x_lens, segment_size=3) + assert segments.shape == (2, 3, 3) + assert all(seg_idxs >= 0), seg_idxs + try: + segments, _ = rand_segments(x, x_lens, segment_size=5) + raise Exception("Should have failed") + except: + pass + x_lens_back = x_lens.clone() + segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) + assert segments.shape == (2, 3, 5) + assert all(seg_idxs >= 0), seg_idxs + assert all(x_lens_back == x_lens) + def generate_path_test(): durations = T.randint(1, 4, (10, 21))