mirror of https://github.com/coqui-ai/TTS.git
make it optional to load linear specs in dataloader and fix tests
respectively
This commit is contained in:
parent
cbcdec83da
commit
3472a41255
|
@ -13,6 +13,7 @@ class MyDataset(Dataset):
|
|||
def __init__(self,
|
||||
outputs_per_step,
|
||||
text_cleaner,
|
||||
compute_linear_spec,
|
||||
ap,
|
||||
meta_data,
|
||||
tp=None,
|
||||
|
@ -28,6 +29,7 @@ class MyDataset(Dataset):
|
|||
Args:
|
||||
outputs_per_step (int): number of time frames predicted per step.
|
||||
text_cleaner (str): text cleaner used for the dataset.
|
||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
ap (TTS.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
|
@ -47,6 +49,7 @@ class MyDataset(Dataset):
|
|||
self.outputs_per_step = outputs_per_step
|
||||
self.sample_rate = ap.sample_rate
|
||||
self.cleaners = text_cleaner
|
||||
self.compute_linear_spec = compute_linear_spec
|
||||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
|
@ -193,7 +196,6 @@ class MyDataset(Dataset):
|
|||
|
||||
# compute features
|
||||
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
|
||||
|
@ -208,25 +210,29 @@ class MyDataset(Dataset):
|
|||
|
||||
# PAD sequences with longest instance in the batch
|
||||
text = prepare_data(text).astype(np.int32)
|
||||
wav = prepare_data(wav)
|
||||
|
||||
# PAD features with longest instance
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||
assert mel.shape[2] == linear.shape[2]
|
||||
|
||||
# B x D x T --> B x T x D
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
# convert things to pytorch
|
||||
text_lenghts = torch.LongTensor(text_lenghts)
|
||||
text = torch.LongTensor(text)
|
||||
linear = torch.FloatTensor(linear).contiguous()
|
||||
mel = torch.FloatTensor(mel).contiguous()
|
||||
mel_lengths = torch.LongTensor(mel_lengths)
|
||||
stop_targets = torch.FloatTensor(stop_targets)
|
||||
|
||||
# compute linear spectrogram
|
||||
if self.compute_linear_spec:
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
assert mel.shape[1] == linear.shape[1]
|
||||
linear = torch.FloatTensor(linear).contiguous()
|
||||
else:
|
||||
linear = None
|
||||
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
|
||||
stop_targets, item_idxs
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
dataset = TTSDataset.MyDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True,
|
||||
ap=self.ap,
|
||||
meta_data=items,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
|
|
1
train.py
1
train.py
|
@ -47,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
|
|||
dataset = MyDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True if c.model.lower() is 'tacotron' else False
|
||||
meta_data=meta_data_eval if is_val else meta_data_train,
|
||||
ap=ap,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
|
|
Loading…
Reference in New Issue