mirror of https://github.com/coqui-ai/TTS.git
commenting TTSDataset.py
This commit is contained in:
parent
79cca4ac80
commit
ee788bc558
|
@ -176,6 +176,8 @@ class MyDataset(Dataset):
|
||||||
if isinstance(batch[0], collections.Mapping):
|
if isinstance(batch[0], collections.Mapping):
|
||||||
|
|
||||||
text_lenghts = np.array([len(d["text"]) for d in batch])
|
text_lenghts = np.array([len(d["text"]) for d in batch])
|
||||||
|
|
||||||
|
# sort items with text input length for RNN efficiency
|
||||||
text_lenghts, ids_sorted_decreasing = torch.sort(
|
text_lenghts, ids_sorted_decreasing = torch.sort(
|
||||||
torch.LongTensor(text_lenghts), dim=0, descending=True)
|
torch.LongTensor(text_lenghts), dim=0, descending=True)
|
||||||
|
|
||||||
|
@ -187,6 +189,7 @@ class MyDataset(Dataset):
|
||||||
speaker_name = [batch[idx]['speaker_name']
|
speaker_name = [batch[idx]['speaker_name']
|
||||||
for idx in ids_sorted_decreasing]
|
for idx in ids_sorted_decreasing]
|
||||||
|
|
||||||
|
# compute features
|
||||||
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||||
|
|
||||||
|
@ -211,7 +214,7 @@ class MyDataset(Dataset):
|
||||||
assert mel.shape[2] == linear.shape[2]
|
assert mel.shape[2] == linear.shape[2]
|
||||||
timesteps = mel.shape[2]
|
timesteps = mel.shape[2]
|
||||||
|
|
||||||
# B x T x D
|
# B x D x T --> B x T x D
|
||||||
linear = linear.transpose(0, 2, 1)
|
linear = linear.transpose(0, 2, 1)
|
||||||
mel = mel.transpose(0, 2, 1)
|
mel = mel.transpose(0, 2, 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue