mirror of https://github.com/coqui-ai/TTS.git
import torch
import numpy as np
from torch.utils.data import Dataset


class WaveRNNDataset(Dataset):
    """
    WaveRNN Dataset searches for all the wav files under the root path
    and converts them to acoustic features on the fly.
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad,
        mode,
        mulaw,
        is_training=True,
        verbose=False,
    ):

        super().__init__()
        self.ap = ap
        # Bare wav paths -> features are computed on the fly;
        # (wav_path, feat_path) pairs -> precomputed features are loaded.
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose

        # The training window must cover a whole number of hops so that
        # audio samples and mel frames stay aligned.
        assert self.seq_len % self.hop_len == 0
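
    # Worked example of the window arithmetic (illustrative values, not from
    # the original file): with seq_len=1280, hop_len=256 and pad=2,
    # mel_len = 1280 // 256 = 5 mel frames per window, collate() crops
    # mel_win = 5 + 2 * 2 = 9 mel frames, and load_item() requires at least
    # min_audio_len = 2 * 1280 + 2 * 2 * 256 = 3584 samples per wav.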

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        return self.load_item(index)

    def load_item(self, index):
        """
        Load a precomputed (mel, quantized-audio) pair if feature paths are
        set, else compute the features on the fly.
        """
        if self.compute_feat:
            # Compute the mel spectrogram directly from the wav file.
            wavpath = self.item_list[index]
            audio = self.ap.load_wav(wavpath)
            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
            if audio.shape[0] < min_audio_len:
                print(" [!] Instance is too short! : {}".format(wavpath))
                audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
            mel = self.ap.melspectrogram(audio)

            if self.mode in ["gauss", "mold"]:
                x_input = audio
            elif isinstance(self.mode, int):
                x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
                           if self.mulaw else self.ap.quantize(audio, bits=self.mode))
            else:
                raise RuntimeError("Unknown dataset mode - {}".format(self.mode))

        else:
            # Load precomputed features from disk.
            wavpath, feat_path = self.item_list[index]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))

            if mel.shape[-1] < self.mel_len + 2 * self.pad:
                # Too short to crop a training window from; fall back to the
                # next item in the list.
                print(" [!] Instance is too short! : {}".format(wavpath))
                self.item_list[index] = self.item_list[index + 1]
                wavpath, feat_path = self.item_list[index]
                mel = np.load(feat_path.replace("/quant/", "/mel/"))
            if self.mode in ["gauss", "mold"]:
                x_input = self.ap.load_wav(wavpath)
            elif isinstance(self.mode, int):
                x_input = np.load(feat_path.replace("/mel/", "/quant/"))
            else:
                raise RuntimeError("Unknown dataset mode - {}".format(self.mode))

        return mel, x_input, wavpath
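
    # Mode summary (from the branches above):
    #   "gauss" / "mold" -> x_input is raw float audio; the model predicts a
    #                       Gaussian / mixture-of-logistics output.
    #   int, e.g. 10     -> x_input holds 2**mode discrete levels, produced by
    #                       mu-law companding when `mulaw` is set, otherwise by
    #                       linear quantization.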

    def collate(self, batch):
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        max_offsets = [x[0].shape[-1] -
                       (mel_win + 2 * self.pad) for x in batch]

        # Pick a random crop per item; the signal offset is the mel offset
        # projected into sample space (skipping the pad frames).
        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        sig_offsets = [(offset + self.pad) *
                       self.hop_len for offset in mel_offsets]

        mels = [
            x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
            for i, x in enumerate(batch)
        ]

        # Crop one extra sample so the target can be the input shifted by one.
        coarse = [
            x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
            for i, x in enumerate(batch)
        ]

        mels = np.stack(mels).astype(np.float32)
        if self.mode in ["gauss", "mold"]:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
            # Rescale the integer levels to floats in [-1, 1] for the input.
            x_input = (2 * coarse[:, : self.seq_len].float() /
                       (2 ** self.mode - 1.0) - 1.0)
        else:
            raise RuntimeError("Unknown dataset mode - {}".format(self.mode))
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse
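

# ---------------------------------------------------------------------------
# Illustrative additions below; they are not part of the original file.
#
# A minimal sketch of mu-law companding in the spirit of the
# `ap.mulaw_encode(audio, qc=...)` call used by `load_item` for integer modes
# with `mulaw=True`; the real AudioProcessor implementation may differ in
# detail.
def mulaw_encode_sketch(audio, qc):
    """Compand float audio in [-1, 1] to integers in [0, 2**qc - 1]."""
    mu = 2 ** qc - 1
    # Log-compress the magnitude while keeping the sign.
    signal = np.sign(audio) * np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    # Map [-1, 1] to [0, mu] and round down to the nearest level.
    return np.floor((signal + 1) / 2 * mu + 0.5).astype(np.int64)


# Minimal usage sketch (assumes `ap` is an AudioProcessor-like object with
# load_wav/melspectrogram/mulaw_encode/quantize, and `wav_paths` is a list of
# wav file paths; all concrete values are illustrative):
#
#     from torch.utils.data import DataLoader
#
#     dataset = WaveRNNDataset(ap, wav_paths, seq_len=1280, hop_len=256,
#                              pad=2, mode=10, mulaw=True)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                         collate_fn=dataset.collate)
#     for x_input, mels, y_coarse in loader:
#         ...  # train the vocoder to predict y_coarse from (x_input, mels)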