diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index d0ba6332..12d27fae 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -73,6 +73,18 @@ class GANDataset(Dataset): item1 = self.load_item(idx) return item1 + def _pad_short_samples(self, audio, mel=None): + """Pad samples shorter than the output sequence length""" + if len(audio) < self.seq_len: + audio = np.pad(audio, (0, self.seq_len - len(audio)), + mode='constant', + constant_values=0.0) + + if mel is not None and mel.shape[1] < self.feat_frame_len: + pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] + mel = np.pad(mel, ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), mode='constant', constant_values=pad_value.mean()) + return audio, mel + def shuffle_mapping(self): random.shuffle(self.G_to_D_mappings) @@ -87,11 +99,7 @@ class GANDataset(Dataset): audio, mel = self.cache[idx] else: audio = self.ap.load_wav(wavpath) - - if len(audio) < self.seq_len + self.pad_short: - audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \ - mode='constant', constant_values=0.0) - + audio, _ = self._pad_short_samples(audio) mel = self.ap.melspectrogram(audio) else: @@ -103,6 +111,7 @@ class GANDataset(Dataset): else: audio = self.ap.load_wav(wavpath) mel = np.load(feat_path) + audio, mel= self._pad_short_samples(audio, mel) # correct the audio length wrt padding applied in stft audio = np.pad(audio, (0, self.hop_len), mode="edge")