mirror of https://github.com/coqui-ai/TTS.git
TTSDataset formatting and batch sorting to use pytorch pack for rnns
This commit is contained in:
parent
007bef5c35
commit
4326582bb1
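Context for the changes below: torch.nn.utils.rnn.pack_padded_sequence requires each batch to be ordered by sequence length, longest first, which is why collate_fn now sorts the batch before padding. A minimal, self-contained sketch of that pattern (tensor sizes are illustrative, not from this repo):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, max_len, emb_dim = 4, 10, 8             # illustrative sizes
lengths = torch.LongTensor([10, 7, 5, 2])           # sorted in descending order
inputs = torch.randn(batch_size, max_len, emb_dim)  # padded batch, B x T x D

rnn = torch.nn.GRU(emb_dim, 16, batch_first=True)
packed = pack_padded_sequence(inputs, lengths, batch_first=True)
outputs, hidden = rnn(packed)                       # RNN skips the padded steps
outputs, _ = pad_packed_sequence(outputs, batch_first=True)  # back to B x T x H

Packing lets the RNN process only the real timesteps, so padding no longer costs compute or pollutes the hidden state.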
@@ -89,7 +89,8 @@ class MyDataset(Dataset):
     def load_phoneme_sequence(self, wav_file, text):
         file_name = os.path.basename(wav_file).split('.')[0]
-        tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy')
+        tmp_path = os.path.join(self.phoneme_cache_path,
+                                file_name + '_phoneme.npy')
         if os.path.isfile(tmp_path):
             try:
                 text = np.load(tmp_path)
@@ -102,7 +103,9 @@ class MyDataset(Dataset):
                 np.save(tmp_path, text)
         else:
             text = np.asarray(
-                phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language), dtype=np.int32)
+                phoneme_to_sequence(
+                    text, [self.cleaners], language=self.phoneme_language),
+                dtype=np.int32)
             np.save(tmp_path, text)
         return text
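The method above implements a simple on-disk cache: phonemization is expensive, so the integer sequence is computed once, saved as a .npy file keyed by the utterance's file name, and reloaded on later epochs. The same idea as a standalone sketch (function and argument names are hypothetical, not from this commit):

import os
import numpy as np

def cached_phoneme_sequence(cache_dir, wav_file, compute_fn):
    # Cache key derived from the audio file name,
    # e.g. LJ001-0001.wav -> LJ001-0001_phoneme.npy
    file_name = os.path.basename(wav_file).split('.')[0]
    cache_path = os.path.join(cache_dir, file_name + '_phoneme.npy')
    if os.path.isfile(cache_path):
        try:
            return np.load(cache_path)
        except (IOError, ValueError):
            pass  # corrupt cache entry: fall through and recompute
    seq = np.asarray(compute_fn(), dtype=np.int32)
    np.save(cache_path, seq)
    return seq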
@@ -112,7 +115,7 @@ class MyDataset(Dataset):
         mel_name = self.items[idx][2]
         linear_name = self.items[idx][3]
         text = self.items[idx][0]

         if wav_name.split('.')[-1] == 'npy':
             wav = self.load_np(wav_name)
         else:
@@ -124,13 +127,19 @@ class MyDataset(Dataset):
             wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
         mel = None
         linear = None

         if self.use_phonemes:
             text = self.load_phoneme_sequence(wav_file, text)
-        else:
+        else:
             text = np.asarray(
                 text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear}
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][1],
+            'mel': mel,
+            'linear': linear
+        }
         return sample

     def sort_items(self):
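For orientation, a hypothetical sketch of how this dataset is wired into training; the batch tuple layout is taken from the cached loader's collate_fn removed later in this diff, and all argument values are illustrative:

from torch.utils.data import DataLoader

# `dataset` is assumed to be an already-constructed MyDataset instance.
loader = DataLoader(
    dataset,
    batch_size=32,                  # illustrative value
    shuffle=False,                  # ordering is handled by sort_items()
    collate_fn=dataset.collate_fn,
    num_workers=4)

for text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs in loader:
    pass  # each batch arrives padded and, after this commit, sorted longest-first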
@@ -151,9 +160,9 @@ class MyDataset(Dataset):
             for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size
                 end_offset = offset + self.batch_group_size
-                temp_items = new_items[offset : end_offset]
+                temp_items = new_items[offset:end_offset]
                 random.shuffle(temp_items)
-                new_items[offset : end_offset] = temp_items
+                new_items[offset:end_offset] = temp_items
         self.items = new_items

         if self.verbose:
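The loop this hunk reformats implements batch-group shuffling: items are pre-sorted by length, and shuffling only inside fixed-size groups keeps similar lengths together (less padding per batch) while still varying batch composition across epochs. Extracted as a standalone helper for clarity (name hypothetical):

import random

def batch_group_shuffle(items, batch_group_size):
    items = list(items)  # items are assumed sorted by length already
    for i in range(len(items) // batch_group_size):
        offset = i * batch_group_size
        end_offset = offset + batch_group_size
        group = items[offset:end_offset]
        random.shuffle(group)        # shuffle only within the group
        items[offset:end_offset] = group
    return items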
@@ -181,19 +190,25 @@ class MyDataset(Dataset):

         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.Mapping):
             keys = list()

-            wav = [d['wav'] for d in batch]
-            item_idxs = [d['item_idx'] for d in batch]
-            text = [d['text'] for d in batch]
+            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lenghts, ids_sorted_decreasing = torch.sort(
+                torch.LongTensor(text_lenghts), dim=0, descending=True)

-            text_lenghts = np.array([len(x) for x in text])
-            max_text_len = np.max(text_lenghts)
+            wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing]
+            item_idxs = [
+                batch[idx]['item_idx'] for idx in ids_sorted_decreasing
+            ]
+            text = [batch[idx]['text'] for idx in ids_sorted_decreasing]

             # if specs are not computed, compute them.
             if batch[0]['mel'] is None and batch[0]['linear'] is None:
-                mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
-                linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+                mel = [
+                    self.ap.melspectrogram(w).astype('float32') for w in wav
+                ]
+                linear = [
+                    self.ap.spectrogram(w).astype('float32') for w in wav
+                ]
             else:
                 mel = [d['mel'] for d in batch]
                 linear = [d['linear'] for d in batch]
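The sorting idiom introduced above, in isolation: torch.sort returns both the sorted lengths and the permutation indices, and the same permutation must be applied to every other field of the batch so the tensors stay aligned with the length ordering that pack_padded_sequence expects. A self-contained sketch with dummy data:

import numpy as np
import torch

batch = [{'text': np.zeros(5)}, {'text': np.zeros(9)}, {'text': np.zeros(3)}]

text_lengths = np.array([len(d['text']) for d in batch])
text_lengths, ids_sorted_decreasing = torch.sort(
    torch.LongTensor(text_lengths), dim=0, descending=True)

# Reorder the remaining fields with the same permutation.
text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
assert [len(t) for t in text] == text_lengths.tolist()  # [9, 5, 3]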
@@ -1,179 +0,0 @@
-import os
-import random
-import numpy as np
-import collections
-import librosa
-import torch
-from tqdm import tqdm
-from torch.utils.data import Dataset
-
-from utils.text import text_to_sequence
-from datasets.preprocess import tts_cache
-from utils.data import (prepare_data, pad_per_step, prepare_tensor,
-                        prepare_stop_target)
-
-
-class MyDataset(Dataset):
-    # TODO: Merge to TTSDataset.py, but it is not fast as it is supposed to be
-    def __init__(self,
-                 root_path,
-                 meta_file,
-                 outputs_per_step,
-                 text_cleaner,
-                 ap,
-                 batch_group_size=0,
-                 min_seq_len=0,
-                 **kwargs
-                 ):
-        self.root_path = root_path
-        self.batch_group_size = batch_group_size
-        self.feat_dir = os.path.join(root_path, 'loader_data')
-        self.items = tts_cache(root_path, meta_file)
-        self.outputs_per_step = outputs_per_step
-        self.sample_rate = ap.sample_rate
-        self.cleaners = text_cleaner
-        self.min_seq_len = min_seq_len
-        self.wavs = None
-        self.mels = None
-        self.linears = None
-        print(" > Reading LJSpeech from - {}".format(root_path))
-        print(" | > Number of instances : {}".format(len(self.items)))
-        self.sort_items()
-        self.fill_data()
-
-    def fill_data(self):
-        if self.wavs is None and self.mels is None:
-            self.wavs = []
-            self.mels = []
-            self.linears = []
-            self.texts = []
-            for item in tqdm(self.items):
-                wav_file = item[0]
-                mel_file = item[1]
-                linear_file = item[2]
-                text = item[-1]
-                wav = self.load_np(wav_file)
-                mel = self.load_np(mel_file)
-                linear = self.load_np(linear_file)
-                self.wavs.append(wav)
-                self.mels.append(mel)
-                self.linears.append(linear)
-                self.texts.append(np.asarray(
-                    text_to_sequence(text, [self.cleaners]), dtype=np.int32))
-            print(" > Data loaded to memory")
-
-    def load_wav(self, filename):
-        try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)
-            return audio
-        except RuntimeError as e:
-            print(" !! Cannot read file : {}".format(filename))
-
-    def load_np(self, filename):
-        data = np.load(filename).astype('float32')
-        return data
-
-    def sort_items(self):
-        r"""Sort text sequences in ascending order"""
-        lengths = np.array([len(ins[-1]) for ins in self.items])
-
-        print(" | > Max length sequence {}".format(np.max(lengths)))
-        print(" | > Min length sequence {}".format(np.min(lengths)))
-        print(" | > Avg length sequence {}".format(np.mean(lengths)))
-
-        idxs = np.argsort(lengths)
-        new_frames = []
-        ignored = []
-        for i, idx in enumerate(idxs):
-            length = lengths[idx]
-            if length < self.min_seq_len:
-                ignored.append(idx)
-            else:
-                new_frames.append(self.items[idx])
-        print(" | > {} instances are ignored by min_seq_len ({})".format(
-            len(ignored), self.min_seq_len))
-        # shuffle batch groups
-        if self.batch_group_size > 0:
-            print(" | > Batch group shuffling is active.")
-            for i in range(len(new_frames) // self.batch_group_size):
-                offset = i * self.batch_group_size
-                end_offset = offset + self.batch_group_size
-                temp_frames = new_frames[offset : end_offset]
-                random.shuffle(temp_frames)
-                new_frames[offset : end_offset] = temp_frames
-        self.items = new_frames
-
-    def __len__(self):
-        return len(self.items)
-
-    def __getitem__(self, idx):
-        text = self.texts[idx]
-        wav = self.wavs[idx]
-        mel = self.mels[idx]
-        linear = self.linears[idx]
-        sample = {
-            'text': text,
-            'wav': wav,
-            'item_idx': self.items[idx][0],
-            'mel': mel,
-            'linear': linear
-        }
-        return sample
-
-    def collate_fn(self, batch):
-        r"""
-        Perform preprocessing and create a final data batch:
-        1. PAD sequences with the longest sequence in the batch
-        2. Convert Audio signal to Spectrograms.
-        3. PAD sequences that can be divided by r.
-        4. Convert Numpy to Torch tensors.
-        """
-
-        # Puts each data field into a tensor with outer dimension batch size
-        if isinstance(batch[0], collections.Mapping):
-            keys = list()
-
-            wav = [d['wav'] for d in batch]
-            item_idxs = [d['item_idx'] for d in batch]
-            text = [d['text'] for d in batch]
-            mel = [d['mel'] for d in batch]
-            linear = [d['linear'] for d in batch]
-
-            text_lenghts = np.array([len(x) for x in text])
-            max_text_len = np.max(text_lenghts)
-            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
-
-            # compute 'stop token' targets
-            stop_targets = [
-                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
-            ]
-
-            # PAD stop targets
-            stop_targets = prepare_stop_target(stop_targets,
-                                               self.outputs_per_step)
-
-            # PAD sequences with largest length of the batch
-            text = prepare_data(text).astype(np.int32)
-            wav = prepare_data(wav)
-
-            # PAD features with largest length + a zero frame
-            linear = prepare_tensor(linear, self.outputs_per_step)
-            mel = prepare_tensor(mel, self.outputs_per_step)
-            timesteps = mel.shape[2]
-
-            # B x T x D
-            linear = linear.transpose(0, 2, 1)
-            mel = mel.transpose(0, 2, 1)
-
-            # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
-            text = torch.LongTensor(text)
-            linear = torch.FloatTensor(linear)
-            mel = torch.FloatTensor(mel)
-            mel_lengths = torch.LongTensor(mel_lengths)
-            stop_targets = torch.FloatTensor(stop_targets)
-
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
-
-        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-            found {}".format(type(batch[0]))))
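One detail of the removed collate_fn worth spelling out is the stop-token target: the decoder learns to emit 0 ("keep going") for every real frame and 1 ("stop") from the end of the utterance onward, with the padded length rounded up to a multiple of r. A hedged reconstruction of that construction; the padding value of 1.0 is an assumption about utils.data.prepare_stop_target, which is not shown in this diff:

import numpy as np

def make_stop_targets(mel_lengths, outputs_per_step):
    max_len = max(mel_lengths)
    # Round the padded length up to a multiple of r (outputs_per_step).
    remainder = max_len % outputs_per_step
    pad_len = max_len if remainder == 0 else max_len + outputs_per_step - remainder
    targets = np.ones((len(mel_lengths), pad_len), dtype=np.float32)
    for i, mel_len in enumerate(mel_lengths):
        targets[i, :mel_len - 1] = 0.0  # zeros while frames remain, ones after
    return targets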
@@ -50,22 +50,12 @@ class MSELossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
         input = input.contiguous()
         target = target.contiguous()

-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.mse_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
-
         # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.mse_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
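The refactor above replaces the flatten-view-mask-sum bookkeeping with a direct multiply-by-mask formulation. Dividing by mask.sum() gives the same normalization as the old length.float().sum() * dim denominator, since the mask contains exactly that many ones. A self-contained check of the new form, with a local stand-in for the repo's sequence_mask helper:

import torch
import torch.nn.functional as functional

def sequence_mask(sequence_length, max_len):
    # Stand-in for the repo's helper: True at valid timesteps, False at padding.
    positions = torch.arange(max_len).unsqueeze(0)   # (1, max_len)
    return positions < sequence_length.unsqueeze(1)  # (batch, max_len)

batch, max_len, dim = 2, 5, 3
inputs = torch.randn(batch, max_len, dim)
target = torch.randn(batch, max_len, dim)
length = torch.LongTensor([5, 3])

mask = sequence_mask(sequence_length=length, max_len=max_len).unsqueeze(2).float()
mask = mask.expand_as(inputs)
loss = functional.mse_loss(inputs * mask, target * mask, reduction="sum")
loss = loss / mask.sum()  # average over valid (unpadded) elements only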