diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index cbb4bf97..319c2598 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -176,6 +176,8 @@ class MyDataset(Dataset): if isinstance(batch[0], collections.Mapping): text_lenghts = np.array([len(d["text"]) for d in batch]) + + # sort items with text input length for RNN efficiency text_lenghts, ids_sorted_decreasing = torch.sort( torch.LongTensor(text_lenghts), dim=0, descending=True) @@ -187,6 +189,7 @@ class MyDataset(Dataset): speaker_name = [batch[idx]['speaker_name'] for idx in ids_sorted_decreasing] + # compute features mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav] @@ -211,7 +214,7 @@ class MyDataset(Dataset): assert mel.shape[2] == linear.shape[2] timesteps = mel.shape[2] - # B x T x D + # B x D x T --> B x T x D linear = linear.transpose(0, 2, 1) mel = mel.transpose(0, 2, 1)