From ab8a4ca2c3dfac43374922777153fe42f04002c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sat, 5 Feb 2022 20:29:40 +0100
Subject: [PATCH] Revert random segment

---
 TTS/tts/models/vits.py   |  3 +--
 TTS/tts/utils/helpers.py | 51 +++++++++++++---------------------------
 2 files changed, 17 insertions(+), 37 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 256ea3af..751187ea 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -656,14 +656,13 @@ class Vits(BaseTTS):
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
-            pad_short=True,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index c2e7f561..9ccb5d62 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -57,7 +57,7 @@ def sequence_mask(sequence_length, max_len=None):
     return mask
 
 
-def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
     """Segment each sample in a batch based on the provided segment indices
 
     Args:
@@ -66,25 +66,16 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
         segment_size (int): Expected output segment size.
         pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size.
     """
-    # pad the input tensor if it is shorter than the segment size
-    if pad_short and x.shape[-1] < segment_size:
-        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
-
-    segments = torch.zeros_like(x[:, :, :segment_size])
-
+    ret = torch.zeros_like(x[:, :, :segment_size])
     for i in range(x.size(0)):
-        index_start = segment_indices[i]
-        index_end = index_start + segment_size
-        x_i = x[i]
-        if pad_short and index_end > x.size(2):
-            # pad the sample if it is shorter than the segment size
-            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
-        segments[i] = x_i[:, index_start:index_end]
-    return segments
+        idx_str = segment_indices[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
 
 
 def rand_segments(
-    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
 ):
     """Create random segments based on the input lengths.
 
@@ -99,25 +90,15 @@ def rand_segments(
         - x: :math:`[B, C, T]`
         - x_lengths: :math:`[B]`
     """
-    _x_lenghts = x_lengths.clone()
-    B, _, T = x.size()
-    if pad_short:
-        if T < segment_size:
-            x = torch.nn.functional.pad(x, (0, segment_size - T))
-            T = segment_size
-    if _x_lenghts is None:
-        _x_lenghts = T
-    len_diff = _x_lenghts - segment_size + 1
-    if let_short_samples:
-        _x_lenghts[len_diff < 0] = segment_size
-        len_diff = _x_lenghts - segment_size + 1
-    else:
-        assert all(
-            len_diff > 0
-        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
-    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
-    ret = segment(x, segment_indices, segment_size)
-    return ret, segment_indices
+    b, _, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    if (ids_str_max < 0).sum():
+        raise ValueError("Segment size is larger than the input length.")
+    ids_str = (torch.rand([b]).to(x.device) * ids_str_max).long()
+    ret = segment(x, ids_str, segment_size)
+    return ret, ids_str
 
 
 def average_over_durations(values, durs):