From 1932401e8d11115efa8eee0a80fa1b265b17fca8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 25 Jan 2022 09:28:48 +0000
Subject: [PATCH] Fix dataset preprocessing

---
 TTS/tts/datasets/dataset.py | 66 +++++++++++++++++--------------------
 1 file changed, 31 insertions(+), 35 deletions(-)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index a1bb23c3..62e146e0 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -140,8 +140,6 @@ class TTSDataset(Dataset):
         self.pitch_computed = False
         self.tokenizer = tokenizer
-        self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
-
         if self.tokenizer.use_phonemes:
             self.phoneme_dataset = PhonemeDataset(
                 self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
             )
@@ -253,16 +251,14 @@ class TTSDataset(Dataset):
         return sample
 
     @staticmethod
-    def compute_lengths(samples):
-        audio_lengths = []
-        text_lengths = []
+    def _compute_lengths(samples):
+        new_samples = []
         for item in samples:
             text, wav_file, *_ = _parse_sample(item)
-            audio_lengths.append(os.path.getsize(wav_file) / 16 * 8)  # assuming 16bit audio
-            text_lengths.append(len(text))
-        audio_lengths = np.array(audio_lengths)
-        text_lengths = np.array(text_lengths)
-        return audio_lengths, text_lengths
+            audio_length = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
+            text_length = len(text)
+            new_samples += [item + [audio_length, text_length]]
+        return new_samples
 
     @staticmethod
     def filter_by_length(lengths: List[int], min_len: int, max_len: int):
@@ -278,8 +274,9 @@ class TTSDataset(Dataset):
         return ignore_idx, keep_idx
 
     @staticmethod
-    def sort_by_length(lengths: List[int]):
-        idxs = np.argsort(lengths)  # ascending order
+    def sort_by_length(samples: List[List]):
+        audio_lengths = [s[-2] for s in samples]
+        idxs = np.argsort(audio_lengths)  # ascending order
         return idxs
 
     @staticmethod
@@ -293,39 +290,38 @@ class TTSDataset(Dataset):
             samples[offset:end_offset] = temp_items
         return samples
 
-    def select_samples_by_idx(self, idxs):
-        samples = []
-        audio_lengths = []
-        text_lengths = []
+    def _select_samples_by_idx(self, idxs, samples):
+        samples_new = []
         for idx in idxs:
-            samples.append(self.samples[idx])
-            audio_lengths.append(self.audio_lengths[idx])
-            text_lengths.append(self.text_lengths[idx])
-        return samples, audio_lengths, text_lengths
+            samples_new.append(samples[idx])
+        return samples_new
 
     def preprocess_samples(self):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the
         length range.
         """
+        samples = self._compute_lengths(self.samples)
 
         # sort items based on the sequence length in ascending order
-        text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
+        text_lengths = [i[-1] for i in samples]
+        audio_lengths = [i[-2] for i in samples]
+        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
         audio_ignore_idx, audio_keep_idx = self.filter_by_length(
-            self.audio_lengths, self.min_audio_len, self.max_audio_len
+            audio_lengths, self.min_audio_len, self.max_audio_len
         )
-        keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
 
-        samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
+        samples = self._select_samples_by_idx(keep_idx, samples)
 
-        sorted_idxs = self.sort_by_length(audio_lengths)
+        sorted_idxs = self.sort_by_length(samples)
 
         if self.start_by_longest:
             longest_idxs = sorted_idxs[-1]
             sorted_idxs[-1] = sorted_idxs[0]
             sorted_idxs[0] = longest_idxs
 
-        samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
+        samples = self._select_samples_by_idx(sorted_idxs, samples)
 
         if len(samples) == 0:
            raise RuntimeError(" [!] No samples left")
@@ -337,19 +333,19 @@ class TTSDataset(Dataset):
             samples = self.create_buckets(samples, self.batch_group_size)
 
         # update items to the new sorted items
-        self.samples = samples
-        self.audio_lengths = audio_lengths
-        self.text_lengths = text_lengtsh
+        audio_lengths = [s[-2] for s in samples]
+        text_lengths = [s[-1] for s in samples]
+        self.samples = [s[:-2] for s in samples]
 
         if self.verbose:
             print(" | > Preprocessing samples")
-            print(" | > Max text length: {}".format(np.max(self.text_lengths)))
-            print(" | > Min text length: {}".format(np.min(self.text_lengths)))
-            print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
+            print(" | > Max text length: {}".format(np.max(text_lengths)))
+            print(" | > Min text length: {}".format(np.min(text_lengths)))
+            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
             print(" | ")
-            print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
-            print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
-            print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
+            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
+            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
+            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
             print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
             print(" | > Batch group size: {}.".format(self.batch_group_size))
 