Revert random segment

This commit is contained in:
Eren Gölge 2022-02-05 20:29:40 +01:00
parent 8622226f3f
commit ab8a4ca2c3
2 changed files with 17 additions and 37 deletions

View File

@@ -656,14 +656,13 @@ class Vits(BaseTTS):
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
# select a random feature segment for the waveform decoder # select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g) o = self.waveform_decoder(z_slice, g=g)
wav_seg = segment( wav_seg = segment(
waveform, waveform,
slice_ids * self.config.audio.hop_length, slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length,
pad_short=True,
) )
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:

View File

@@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None):
return mask return mask
def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
"""Segment each sample in a batch based on the provided segment indices """Segment each sample in a batch based on the provided segment indices
Args: Args:
@@ -66,25 +66,16 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
segment_size (int): Expected output segment size. segment_size (int): Expected output segment size.
pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
""" """
# pad the input tensor if it is shorter than the segment size ret = torch.zeros_like(x[:, :, :segment_size])
if pad_short and x.shape[-1] < segment_size:
x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
segments = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)): for i in range(x.size(0)):
index_start = segment_indices[i] idx_str = segment_indices[i]
index_end = index_start + segment_size idx_end = idx_str + segment_size
x_i = x[i] ret[i] = x[i, :, idx_str:idx_end]
if pad_short and index_end > x.size(2): return ret
# pad the sample if it is shorter than the segment size
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
segments[i] = x_i[:, index_start:index_end]
return segments
def rand_segments( def rand_segments(
x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
): ):
"""Create random segments based on the input lengths. """Create random segments based on the input lengths.
@@ -99,25 +90,15 @@ def rand_segments(
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_lengths: :math:`[B]` - x_lengths: :math:`[B]`
""" """
_x_lenghts = x_lengths.clone() b, _, t = x.size()
B, _, T = x.size() if x_lengths is None:
if pad_short: x_lengths = t
if T < segment_size: ids_str_max = x_lengths - segment_size + 1
x = torch.nn.functional.pad(x, (0, segment_size - T)) if (ids_str_max < 0).sum():
T = segment_size raise ValueError("Segment size is larger than the input length.")
if _x_lenghts is None: ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long()
_x_lenghts = T ret = segment(x, ids_str, segment_size)
len_diff = _x_lenghts - segment_size + 1 return ret, ids_str
if let_short_samples:
_x_lenghts[len_diff < 0] = segment_size
len_diff = _x_lenghts - segment_size + 1
else:
assert all(
len_diff > 0
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
def average_over_durations(values, durs): def average_over_durations(values, durs):