Allow padding for shorter segments

This commit is contained in:
Eren Gölge 2022-01-21 15:27:41 +00:00
parent 47fbddc8d4
commit c4c471d61d
2 changed files with 58 additions and 9 deletions

View File

@ -57,40 +57,61 @@ def sequence_mask(sequence_length, max_len=None):
return mask
def segment(x: torch.Tensor, segment_indices: torch.Tensor, segment_size=4, pad_short=False):
    """Cut one fixed-size segment out of each sample in a batch.

    Args:
        x (torch.Tensor): Input tensor, indexed as `[B, C, T]`.
        segment_indices (torch.Tensor): Start index of the segment for each sample, one entry per batch item.
        segment_size (int): Expected output segment size.
        pad_short (bool): Pad the end of the input with zeros whenever a requested segment
            runs past the last frame. If False, such a segment raises an error because the
            slice does not match the output segment shape.

    Returns:
        torch.Tensor: Segments with the same batch and channel sizes as `x` and `segment_size` frames
        (fewer frames only if `x` is shorter than `segment_size` and `pad_short` is False).
    """
    # pad the whole batch if it is shorter than a single segment
    if pad_short and x.shape[-1] < segment_size:
        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))

    total_len = x.size(2)
    segments = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        # plain ints keep the slice bounds and the pad size independent of tensor semantics
        index_start = int(segment_indices[i])
        index_end = index_start + segment_size
        x_i = x[i]
        if pad_short and index_end > total_len:
            # pad only this sample, and only by the number of missing frames
            x_i = torch.nn.functional.pad(x_i, (0, index_end - total_len))
        segments[i] = x_i[:, index_start:index_end]
    return segments
def rand_segments(
    x: torch.Tensor, x_lengths: torch.Tensor = None, segment_size=4, let_short_samples=False, pad_short=False
):
    """Create random segments based on the input lengths.

    Args:
        x (torch.Tensor): Input tensor.
        x_lengths (torch.Tensor): Input lengths. If None, every sample is assumed to span the full
            (possibly padded) time axis of `x`. The passed tensor is never modified.
        segment_size (int): Expected output segment size.
        let_short_samples (bool): Allow shorter samples than the segment size. Their segment always
            starts at index 0.
        pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The extracted segments and the start index of each segment.

    Shapes:
        - x: :math:`[B, C, T]`
        - x_lengths: :math:`[B]`
    """
    B, _, T = x.size()
    if pad_short and T < segment_size:
        x = torch.nn.functional.pad(x, (0, segment_size - T))
        T = segment_size

    if x_lengths is None:
        # no lengths given: treat every sample as full length
        lengths = torch.full((B,), T, dtype=torch.long, device=x.device)
    else:
        # work on a copy so the caller's lengths stay untouched
        lengths = x_lengths.clone()

    if let_short_samples:
        # short samples are treated as exactly one segment long, which pins their start index to 0
        lengths = lengths.clamp(min=segment_size)
    len_diff = lengths - segment_size + 1
    if not let_short_samples:
        assert all(
            len_diff > 0
        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {lengths}"

    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
    ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
    return ret, segment_indices

View File

@ -1,6 +1,6 @@
import torch as T
from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask
from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
def average_over_durations_test(): # pylint: disable=no-self-use
@ -39,6 +39,34 @@ def segment_test():
for idx, start_indx in enumerate(segment_ids):
assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()
try:
segments = segment(x, segment_ids, segment_size=10)
raise Exception("Should have failed")
except:
pass
segments = segment(x, segment_ids, segment_size=10, pad_short=True)
for idx, start_indx in enumerate(segment_ids):
assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()
def rand_segments_test():
    """Check `rand_segments` output shapes, its short-sample rejection, and the padding path."""
    x = T.rand(2, 3, 4)
    x_lens = T.randint(3, 4, (2,))
    segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
    assert segments.shape == (2, 3, 3)
    assert all(seg_idxs >= 0), seg_idxs

    # Samples shorter than the segment must be rejected unless short samples are allowed.
    # The failure is raised in the `else` branch so the handler cannot swallow it.
    try:
        rand_segments(x, x_lens, segment_size=5)
    except AssertionError:
        pass
    else:
        raise AssertionError("Should have failed")

    # Pass the original lengths (not a copy) so the final check really verifies
    # that `rand_segments` does not modify its input.
    x_lens_back = x_lens.clone()
    segments, seg_idxs = rand_segments(x, x_lens, segment_size=5, pad_short=True, let_short_samples=True)
    assert segments.shape == (2, 3, 5)
    assert all(seg_idxs >= 0), seg_idxs
    assert all(x_lens_back == x_lens)
def generate_path_test():
durations = T.randint(1, 4, (10, 21))