Make loading linear spectrograms in the dataloader optional and fix the tests

accordingly
This commit is contained in:
erogol 2020-03-10 18:17:35 +01:00
parent cbcdec83da
commit 3472a41255
3 changed files with 14 additions and 6 deletions

View File

@ -13,6 +13,7 @@ class MyDataset(Dataset):
def __init__(self,
outputs_per_step,
text_cleaner,
compute_linear_spec,
ap,
meta_data,
tp=None,
@ -28,6 +29,7 @@ class MyDataset(Dataset):
Args:
outputs_per_step (int): number of time frames predicted per step.
text_cleaner (str): text cleaner used for the dataset.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
batch_group_size (int): (0) range of batch randomization after sorting
@ -47,6 +49,7 @@ class MyDataset(Dataset):
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
@ -193,7 +196,6 @@ class MyDataset(Dataset):
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] for m in mel]
@ -208,25 +210,29 @@ class MyDataset(Dataset):
# PAD sequences with longest instance in the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with longest instance
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
# B x D x T --> B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear).contiguous()
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
linear = torch.FloatTensor(linear).contiguous()
else:
linear = None
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs

View File

@ -36,6 +36,7 @@ class TestTTSDataset(unittest.TestCase):
dataset = TTSDataset.MyDataset(
r,
c.text_cleaner,
compute_linear_spec=True,
ap=self.ap,
meta_data=items,
tp=c.characters if 'characters' in c.keys() else None,

View File

@ -47,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
dataset = MyDataset(
r,
c.text_cleaner,
compute_linear_spec=True if c.model.lower() is 'tacotron' else False
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,