TTSDataset formatting and batch sorting to use PyTorch pack for RNNs

Eren Golge 2019-03-06 13:10:05 +01:00
parent 007bef5c35
commit 4326582bb1
3 changed files with 37 additions and 211 deletions

@@ -89,7 +89,8 @@ class MyDataset(Dataset):
     def load_phoneme_sequence(self, wav_file, text):
         file_name = os.path.basename(wav_file).split('.')[0]
-        tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy')
+        tmp_path = os.path.join(self.phoneme_cache_path,
+                                file_name + '_phoneme.npy')
         if os.path.isfile(tmp_path):
             try:
                 text = np.load(tmp_path)
@@ -102,7 +103,9 @@ class MyDataset(Dataset):
                 np.save(tmp_path, text)
         else:
             text = np.asarray(
-                phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language), dtype=np.int32)
+                phoneme_to_sequence(
+                    text, [self.cleaners], language=self.phoneme_language),
+                dtype=np.int32)
             np.save(tmp_path, text)
         return text
@@ -112,7 +115,7 @@ class MyDataset(Dataset):
         mel_name = self.items[idx][2]
         linear_name = self.items[idx][3]
         text = self.items[idx][0]
         if wav_name.split('.')[-1] == 'npy':
             wav = self.load_np(wav_name)
         else:
@@ -124,13 +127,19 @@ class MyDataset(Dataset):
             wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
         mel = None
         linear = None
         if self.use_phonemes:
             text = self.load_phoneme_sequence(wav_file, text)
-        else:
+        else:
             text = np.asarray(
                 text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear}
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][1],
+            'mel': mel,
+            'linear': linear
+        }
         return sample

     def sort_items(self):
@@ -151,9 +160,9 @@ class MyDataset(Dataset):
             for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size
                 end_offset = offset + self.batch_group_size
-                temp_items = new_items[offset : end_offset]
+                temp_items = new_items[offset:end_offset]
                 random.shuffle(temp_items)
-                new_items[offset : end_offset] = temp_items
+                new_items[offset:end_offset] = temp_items
         self.items = new_items

         if self.verbose:
@@ -181,19 +190,25 @@ class MyDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.Mapping):
             keys = list()
-            wav = [d['wav'] for d in batch]
-            item_idxs = [d['item_idx'] for d in batch]
-            text = [d['text'] for d in batch]
+            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lenghts, ids_sorted_decreasing = torch.sort(
+                torch.LongTensor(text_lenghts), dim=0, descending=True)
-            text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
+            wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing]
+            item_idxs = [
+                batch[idx]['item_idx'] for idx in ids_sorted_decreasing
+            ]
+            text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
             # if specs are not computed, compute them.
             if batch[0]['mel'] is None and batch[0]['linear'] is None:
-                mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
-                linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+                mel = [
+                    self.ap.melspectrogram(w).astype('float32') for w in wav
+                ]
+                linear = [
+                    self.ap.spectrogram(w).astype('float32') for w in wav
+                ]
             else:
                 mel = [d['mel'] for d in batch]
                 linear = [d['linear'] for d in batch]
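
The point of the new descending sort is that torch.nn.utils.rnn.pack_padded_sequence expects batches ordered from longest to shortest sequence (later PyTorch releases can relax this with enforce_sorted=False). A minimal sketch of the model-side round trip this collate_fn enables; the nn.GRU encoder and the shapes are illustrative assumptions, not the repository's model code:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy batch: 4 padded sequences, already sorted by decreasing length,
# exactly as the new collate_fn guarantees.
inputs = torch.randn(4, 10, 8)         # B x T_max x D
lengths = torch.tensor([10, 7, 5, 2])  # descending, as produced by torch.sort

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

# Pack so the RNN skips padded timesteps instead of processing them.
packed = pack_padded_sequence(inputs, lengths, batch_first=True)
packed_out, hidden = rnn(packed)

# Unpack back to a padded B x T_max x H tensor for downstream layers.
outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(outputs.shape)  # torch.Size([4, 10, 16])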

@@ -1,179 +0,0 @@
import os
import random
import numpy as np
import collections
import librosa
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from datasets.preprocess import tts_cache
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                        prepare_stop_target)


class MyDataset(Dataset):
    # TODO: Merge to TTSDataset.py, but it is not fast as it is supposed to be
    def __init__(self,
                 root_path,
                 meta_file,
                 outputs_per_step,
                 text_cleaner,
                 ap,
                 batch_group_size=0,
                 min_seq_len=0,
                 **kwargs
                 ):
        self.root_path = root_path
        self.batch_group_size = batch_group_size
        self.feat_dir = os.path.join(root_path, 'loader_data')
        self.items = tts_cache(root_path, meta_file)
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        self.wavs = None
        self.mels = None
        self.linears = None
        print(" > Reading LJSpeech from - {}".format(root_path))
        print(" | > Number of instances : {}".format(len(self.items)))
        self.sort_items()
        self.fill_data()

    def fill_data(self):
        if self.wavs is None and self.mels is None:
            self.wavs = []
            self.mels = []
            self.linears = []
            self.texts = []
            for item in tqdm(self.items):
                wav_file = item[0]
                mel_file = item[1]
                linear_file = item[2]
                text = item[-1]
                wav = self.load_np(wav_file)
                mel = self.load_np(mel_file)
                linear = self.load_np(linear_file)
                self.wavs.append(wav)
                self.mels.append(mel)
                self.linears.append(linear)
                self.texts.append(np.asarray(
                    text_to_sequence(text, [self.cleaners]), dtype=np.int32))
            print(" > Data loaded to memory")

    def load_wav(self, filename):
        try:
            audio = librosa.core.load(filename, sr=self.sample_rate)
            return audio
        except RuntimeError as e:
            print(" !! Cannot read file : {}".format(filename))

    def load_np(self, filename):
        data = np.load(filename).astype('float32')
        return data

    def sort_items(self):
        r"""Sort text sequences in ascending order"""
        lengths = np.array([len(ins[-1]) for ins in self.items])

        print(" | > Max length sequence {}".format(np.max(lengths)))
        print(" | > Min length sequence {}".format(np.min(lengths)))
        print(" | > Avg length sequence {}".format(np.mean(lengths)))

        idxs = np.argsort(lengths)
        new_frames = []
        ignored = []
        for i, idx in enumerate(idxs):
            length = lengths[idx]
            if length < self.min_seq_len:
                ignored.append(idx)
            else:
                new_frames.append(self.items[idx])
        print(" | > {} instances are ignored by min_seq_len ({})".format(
            len(ignored), self.min_seq_len))

        # shuffle batch groups
        if self.batch_group_size > 0:
            print(" | > Batch group shuffling is active.")
            for i in range(len(new_frames) // self.batch_group_size):
                offset = i * self.batch_group_size
                end_offset = offset + self.batch_group_size
                temp_frames = new_frames[offset : end_offset]
                random.shuffle(temp_frames)
                new_frames[offset : end_offset] = temp_frames
        self.items = new_frames

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text = self.texts[idx]
        wav = self.wavs[idx]
        mel = self.mels[idx]
        linear = self.linears[idx]
        sample = {
            'text': text,
            'wav': wav,
            'item_idx': self.items[idx][0],
            'mel': mel,
            'linear': linear
        }
        return sample

    def collate_fn(self, batch):
        r"""
            Perform preprocessing and create a final data batch:
            1. PAD sequences with the longest sequence in the batch
            2. Convert Audio signal to Spectrograms.
            3. PAD sequences that can be divided by r.
            4. Convert Numpy to Torch tensors.
        """

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.Mapping):
            keys = list()

            wav = [d['wav'] for d in batch]
            item_idxs = [d['item_idx'] for d in batch]
            text = [d['text'] for d in batch]
            mel = [d['mel'] for d in batch]
            linear = [d['linear'] for d in batch]

            text_lenghts = np.array([len(x) for x in text])
            max_text_len = np.max(text_lenghts)

            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

            # compute 'stop token' targets
            stop_targets = [
                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
            ]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)

            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)

            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            timesteps = mel.shape[2]

            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            text_lenghts = torch.LongTensor(text_lenghts)
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear)
            mel = torch.FloatTensor(mel)
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs

        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                         found {}".format(type(batch[0]))))

@@ -50,22 +50,12 @@ class MSELossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
-        input = input.contiguous()
-        target = target.contiguous()
-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.mse_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.mse_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
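
The rewritten forward multiplies both prediction and target by an expanded float mask and normalizes by mask.sum(), i.e. by the count of unmasked elements, instead of flattening the tensors and dividing by lengths times the feature dimension. sequence_mask itself is not part of this diff; the version below is a common implementation, assumed here only so the snippet runs end to end:

import torch
from torch.nn import functional

def sequence_mask(sequence_length, max_len=None):
    # Assumed implementation, not taken from this diff: True for valid steps.
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)  # B x T_max

B, T, D = 2, 6, 4
pred = torch.randn(B, T, D)
target = torch.randn(B, T, D)
lengths = torch.tensor([6, 3])

mask = sequence_mask(lengths, max_len=T).unsqueeze(2).float()  # B x T x 1
mask = mask.expand_as(pred)                                    # B x T x D
# Padded positions are zeroed on both sides, so they add nothing to the sum;
# dividing by mask.sum() averages over the real elements only.
loss = functional.mse_loss(pred * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
print(loss.item())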