From 4326582bb1e68480ef79a02abbf4bfacc3aadede Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 6 Mar 2019 13:10:05 +0100 Subject: [PATCH] TTSDataset formatting and batch sorting to use pytorch pack for rnns --- datasets/TTSDataset.py | 47 +++++---- datasets/TTSDatasetMemory.py | 179 ----------------------------------- layers/losses.py | 22 ++--- 3 files changed, 37 insertions(+), 211 deletions(-) delete mode 100644 datasets/TTSDatasetMemory.py diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index 497b9b2c..30ead9ad 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -89,7 +89,8 @@ class MyDataset(Dataset): def load_phoneme_sequence(self, wav_file, text): file_name = os.path.basename(wav_file).split('.')[0] - tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy') + tmp_path = os.path.join(self.phoneme_cache_path, + file_name + '_phoneme.npy') if os.path.isfile(tmp_path): try: text = np.load(tmp_path) @@ -102,7 +103,9 @@ class MyDataset(Dataset): np.save(tmp_path, text) else: text = np.asarray( - phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language), dtype=np.int32) + phoneme_to_sequence( + text, [self.cleaners], language=self.phoneme_language), + dtype=np.int32) np.save(tmp_path, text) return text @@ -112,7 +115,7 @@ class MyDataset(Dataset): mel_name = self.items[idx][2] linear_name = self.items[idx][3] text = self.items[idx][0] - + if wav_name.split('.')[-1] == 'npy': wav = self.load_np(wav_name) else: @@ -124,13 +127,19 @@ class MyDataset(Dataset): wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) mel = None linear = None - + if self.use_phonemes: text = self.load_phoneme_sequence(wav_file, text) - else: + else: text = np.asarray( text_to_sequence(text, [self.cleaners]), dtype=np.int32) - sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear} + sample = { + 'text': text, + 'wav': wav, + 'item_idx': self.items[idx][1], + 'mel': mel, + 'linear': linear + } return sample def sort_items(self): @@ -151,9 +160,9 @@ class MyDataset(Dataset): for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size end_offset = offset + self.batch_group_size - temp_items = new_items[offset : end_offset] + temp_items = new_items[offset:end_offset] random.shuffle(temp_items) - new_items[offset : end_offset] = temp_items + new_items[offset:end_offset] = temp_items self.items = new_items if self.verbose: @@ -181,19 +190,25 @@ class MyDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.Mapping): - keys = list() - wav = [d['wav'] for d in batch] - item_idxs = [d['item_idx'] for d in batch] - text = [d['text'] for d in batch] + text_lenghts = np.array([len(d["text"]) for d in batch]) + text_lenghts, ids_sorted_decreasing = torch.sort( + torch.LongTensor(text_lenghts), dim=0, descending=True) - text_lenghts = np.array([len(x) for x in text]) - max_text_len = np.max(text_lenghts) + wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing] + item_idxs = [ + batch[idx]['item_idx'] for idx in ids_sorted_decreasing + ] + text = [batch[idx]['text'] for idx in ids_sorted_decreasing] # if specs are not computed, compute them. if batch[0]['mel'] is None and batch[0]['linear'] is None: - mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] - linear = [self.ap.spectrogram(w).astype('float32') for w in wav] + mel = [ + self.ap.melspectrogram(w).astype('float32') for w in wav + ] + linear = [ + self.ap.spectrogram(w).astype('float32') for w in wav + ] else: mel = [d['mel'] for d in batch] linear = [d['linear'] for d in batch] diff --git a/datasets/TTSDatasetMemory.py b/datasets/TTSDatasetMemory.py deleted file mode 100644 index 0c799f46..00000000 --- a/datasets/TTSDatasetMemory.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import random -import numpy as np -import collections -import librosa -import torch -from tqdm import tqdm -from torch.utils.data import Dataset - -from utils.text import text_to_sequence -from datasets.preprocess import tts_cache -from utils.data import (prepare_data, pad_per_step, prepare_tensor, - prepare_stop_target) - - -class MyDataset(Dataset): - # TODO: Merge to TTSDataset.py, but it is not fast as it is supposed to be - def __init__(self, - root_path, - meta_file, - outputs_per_step, - text_cleaner, - ap, - batch_group_size=0, - min_seq_len=0, - **kwargs - ): - self.root_path = root_path - self.batch_group_size = batch_group_size - self.feat_dir = os.path.join(root_path, 'loader_data') - self.items = tts_cache(root_path, meta_file) - self.outputs_per_step = outputs_per_step - self.sample_rate = ap.sample_rate - self.cleaners = text_cleaner - self.min_seq_len = min_seq_len - self.wavs = None - self.mels = None - self.linears = None - print(" > Reading LJSpeech from - {}".format(root_path)) - print(" | > Number of instances : {}".format(len(self.items))) - self.sort_items() - self.fill_data() - - def fill_data(self): - if self.wavs is None and self.mels is None: - self.wavs = [] - self.mels = [] - self.linears = [] - self.texts = [] - for item in tqdm(self.items): - wav_file = item[0] - mel_file = item[1] - linear_file = item[2] - text = item[-1] - wav = self.load_np(wav_file) - mel = self.load_np(mel_file) - linear = self.load_np(linear_file) - self.wavs.append(wav) - self.mels.append(mel) - self.linears.append(linear) - self.texts.append(np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32)) - print(" > Data loaded to memory") - - def load_wav(self, filename): - try: - audio = librosa.core.load(filename, sr=self.sample_rate) - return audio - except RuntimeError as e: - print(" !! Cannot read file : {}".format(filename)) - - def load_np(self, filename): - data = np.load(filename).astype('float32') - return data - - def sort_items(self): - r"""Sort text sequences in ascending order""" - lengths = np.array([len(ins[-1]) for ins in self.items]) - - print(" | > Max length sequence {}".format(np.max(lengths))) - print(" | > Min length sequence {}".format(np.min(lengths))) - print(" | > Avg length sequence {}".format(np.mean(lengths))) - - idxs = np.argsort(lengths) - new_frames = [] - ignored = [] - for i, idx in enumerate(idxs): - length = lengths[idx] - if length < self.min_seq_len: - ignored.append(idx) - else: - new_frames.append(self.items[idx]) - print(" | > {} instances are ignored by min_seq_len ({})".format( - len(ignored), self.min_seq_len)) - # shuffle batch groups - if self.batch_group_size > 0: - print(" | > Batch group shuffling is active.") - for i in range(len(new_frames) // self.batch_group_size): - offset = i * self.batch_group_size - end_offset = offset + self.batch_group_size - temp_frames = new_frames[offset : end_offset] - random.shuffle(temp_frames) - new_frames[offset : end_offset] = temp_frames - self.items = new_frames - - def __len__(self): - return len(self.items) - - def __getitem__(self, idx): - text = self.texts[idx] - wav = self.wavs[idx] - mel = self.mels[idx] - linear = self.linears[idx] - sample = { - 'text': text, - 'wav': wav, - 'item_idx': self.items[idx][0], - 'mel': mel, - 'linear': linear - } - return sample - - def collate_fn(self, batch): - r""" - Perform preprocessing and create a final data batch: - 1. PAD sequences with the longest sequence in the batch - 2. Convert Audio signal to Spectrograms. - 3. PAD sequences that can be divided by r. - 4. Convert Numpy to Torch tensors. - """ - - # Puts each data field into a tensor with outer dimension batch size - if isinstance(batch[0], collections.Mapping): - keys = list() - - wav = [d['wav'] for d in batch] - item_idxs = [d['item_idx'] for d in batch] - text = [d['text'] for d in batch] - mel = [d['mel'] for d in batch] - linear = [d['linear'] for d in batch] - - text_lenghts = np.array([len(x) for x in text]) - max_text_len = np.max(text_lenghts) - mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame - - # compute 'stop token' targets - stop_targets = [ - np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths - ] - - # PAD stop targets - stop_targets = prepare_stop_target(stop_targets, - self.outputs_per_step) - - # PAD sequences with largest length of the batch - text = prepare_data(text).astype(np.int32) - wav = prepare_data(wav) - - # PAD features with largest length + a zero frame - linear = prepare_tensor(linear, self.outputs_per_step) - mel = prepare_tensor(mel, self.outputs_per_step) - timesteps = mel.shape[2] - - # B x T x D - linear = linear.transpose(0, 2, 1) - mel = mel.transpose(0, 2, 1) - - # convert things to pytorch - text_lenghts = torch.LongTensor(text_lenghts) - text = torch.LongTensor(text) - linear = torch.FloatTensor(linear) - mel = torch.FloatTensor(mel) - mel_lengths = torch.LongTensor(mel_lengths) - stop_targets = torch.FloatTensor(stop_targets) - - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs - - raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}".format(type(batch[0])))) diff --git a/layers/losses.py b/layers/losses.py index 4be86424..4e8cea81 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -50,22 +50,12 @@ class MSELossMasked(nn.Module): Returns: loss: An average loss value masked by the length. """ - input = input.contiguous() - target = target.contiguous() - - # logits_flat: (batch * max_len, dim) - input = input.view(-1, input.shape[-1]) - # target_flat: (batch * max_len, dim) - target_flat = target.view(-1, target.shape[-1]) - # losses_flat: (batch * max_len, dim) - losses_flat = functional.mse_loss( - input, target_flat, size_average=False, reduce=False) - # losses: (batch, max_len, dim) - losses = losses_flat.view(*target.size()) - # mask: (batch, max_len, 1) mask = sequence_mask( - sequence_length=length, max_len=target.size(1)).unsqueeze(2) - losses = losses * mask.float() - loss = losses.sum() / (length.float().sum() * float(target.shape[2])) + sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + mask = mask.expand_as(input) + loss = functional.mse_loss( + input * mask, target * mask, reduction="sum") + loss = loss / mask.sum() return loss +