mirror of https://github.com/coqui-ai/TTS.git
Revert back again rand_segment
This commit is contained in:
parent 00c7600103
commit c11944022d
@@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None):
     return mask
 
 
-def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
     """Segment each sample in a batch based on the provided segment indices
 
     Args:
@@ -66,16 +66,25 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
         segment_size (int): Expected output segment size.
         pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
     """
-    ret = torch.zeros_like(x[:, :, :segment_size])
+    # pad the input tensor if it is shorter than the segment size
+    if pad_short and x.shape[-1] < segment_size:
+        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
+
+    segments = torch.zeros_like(x[:, :, :segment_size])
+
     for i in range(x.size(0)):
-        idx_str = segment_indices[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
+        index_start = segment_indices[i]
+        index_end = index_start + segment_size
+        x_i = x[i]
+        if pad_short and index_end > x.size(2):
+            # pad the sample if it is shorter than the segment size
+            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
+        segments[i] = x_i[:, index_start:index_end]
+    return segments
 
 
 def rand_segments(
-    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
 ):
     """Create random segments based on the input lengths.
 
@@ -90,16 +99,25 @@ def rand_segments(
         - x: :math:`[B, C, T]`
         - x_lengths: :math:`[B]`
     """
-    b, _, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    if (ids_str_max < 0).sum():
-        raise ValueError("Segment size is larger than the input length.")
-    ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long()
-    ret = segment(x, ids_str, segment_size)
-    return ret, ids_str
+    _x_lenghts = x_lengths.clone()
+    B, _, T = x.size()
+    if pad_short:
+        if T < segment_size:
+            x = torch.nn.functional.pad(x, (0, segment_size - T))
+            T = segment_size
+    if _x_lenghts is None:
+        _x_lenghts = T
+    len_diff = _x_lenghts - segment_size + 1
+    if let_short_samples:
+        _x_lenghts[len_diff < 0] = segment_size
+        len_diff = _x_lenghts - segment_size + 1
+    else:
+        assert all(
+            len_diff > 0
+        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
+    ret = segment(x, segment_indices, segment_size)
+    return ret, segment_indices
 
 
 def average_over_durations(values, durs):
     """Average values over durations.
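A minimal usage sketch of the two functions as they stand after this commit. The import path TTS.tts.utils.helpers is an assumption (the diff does not show the file name), and all shapes and lengths below are made-up example values, not part of the commit.

import torch

# Assumed import path; the diff above does not name the file.
from TTS.tts.utils.helpers import segment, rand_segments

# segment() with pad_short=True: a window that runs past the end of a sample
# is zero-padded on the right instead of raising an indexing error.
x = torch.arange(12.0).view(1, 2, 6)    # [B=1, C=2, T=6]
idx = torch.tensor([4])                 # window 4:8 exceeds T=6
out = segment(x, idx, segment_size=4, pad_short=True)
print(out.shape)                        # torch.Size([1, 2, 4]); last two frames are zeros

# rand_segments() with let_short_samples=True: samples whose length is smaller
# than segment_size get clamped up to segment_size instead of failing the
# assert, so their window simply reads past the valid length of that sample.
x = torch.randn(3, 80, 10)              # [B, C, T]
x_lengths = torch.tensor([10, 9, 5])    # last sample is shorter than segment_size
segs, start_indices = rand_segments(
    x, x_lengths, segment_size=8, let_short_samples=True, pad_short=True
)
print(segs.shape)                       # torch.Size([3, 80, 8])
print(start_indices)                    # one start index per sample; 0 for the short one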