commenting TTSDataset.py

This commit is contained in:
Eren Golge 2019-11-19 12:39:31 +01:00
parent 79cca4ac80
commit ee788bc558
1 changed file with 4 additions and 1 deletion

View File

@@ -176,6 +176,8 @@ class MyDataset(Dataset):
if isinstance(batch[0], collections.Mapping): if isinstance(batch[0], collections.Mapping):
text_lenghts = np.array([len(d["text"]) for d in batch]) text_lenghts = np.array([len(d["text"]) for d in batch])
# sort items by text input length (descending) for RNN efficiency
text_lenghts, ids_sorted_decreasing = torch.sort( text_lenghts, ids_sorted_decreasing = torch.sort(
torch.LongTensor(text_lenghts), dim=0, descending=True) torch.LongTensor(text_lenghts), dim=0, descending=True)
@@ -187,6 +189,7 @@ class MyDataset(Dataset):
speaker_name = [batch[idx]['speaker_name'] speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing] for idx in ids_sorted_decreasing]
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@@ -211,7 +214,7 @@ class MyDataset(Dataset):
assert mel.shape[2] == linear.shape[2] assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2] timesteps = mel.shape[2]
# B x T x D # B x D x T --> B x T x D
linear = linear.transpose(0, 2, 1) linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1) mel = mel.transpose(0, 2, 1)