From 6a9f8074f09a5960e7bc270de69b593281553b06 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Tue, 1 Mar 2022 07:57:48 +0100
Subject: [PATCH] Fix TTSDataset

---
 TTS/tts/datasets/dataset.py | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index d4d1a7e5..d8f16e4e 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -200,8 +200,8 @@ class TTSDataset(Dataset):
 
     def get_f0(self, idx):
         out_dict = self.f0_dataset[idx]
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        assert wav_file == out_dict["audio_file"]
+        item = self.samples[idx]
+        assert item["audio_file"] == out_dict["audio_file"]
         return out_dict
 
     @staticmethod
@@ -263,10 +263,11 @@ class TTSDataset(Dataset):
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            text, wav_file, *_ = _parse_sample(item)
-            audio_length = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
-            text_lenght = len(text)
-            new_samples += [item + [audio_length, text_lenght]]
+            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
+            text_lenght = len(item["text"])
+            item["audio_length"] = audio_length
+            item["text_length"] = text_lenght
+            new_samples += [item]
         return new_samples
 
     @staticmethod
@@ -284,7 +285,7 @@
 
     @staticmethod
     def sort_by_length(samples: List[List]):
-        audio_lengths = [s[-2] for s in samples]
+        audio_lengths = [s["audio_length"] for s in samples]
         idxs = np.argsort(audio_lengths)  # ascending order
         return idxs
 
@@ -313,8 +314,8 @@
         samples = self._compute_lengths(self.samples)
 
         # sort items based on the sequence length in ascending order
-        text_lengths = [i[-1] for i in samples]
-        audio_lengths = [i[-2] for i in samples]
+        text_lengths = [i["text_length"] for i in samples]
+        audio_lengths = [i["audio_length"] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
         audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
@@ -341,9 +342,9 @@
             samples = self.create_buckets(samples, self.batch_group_size)
 
         # update items to the new sorted items
-        audio_lengths = [s[-2] for s in samples]
-        text_lengths = [s[-1] for s in samples]
-        self.samples = [s[:-2] for s in samples]
+        audio_lengths = [s["audio_length"] for s in samples]
+        text_lengths = [s["text_length"] for s in samples]
+        self.samples = samples
 
         if self.verbose:
             print(" | > Preprocessing samples")
@@ -558,10 +559,10 @@ class PhonemeDataset(Dataset):
             self.precompute(precompute_num_workers)
 
     def __getitem__(self, index):
-        text, wav_file, *_ = _parse_sample(self.samples[index])
-        ids = self.compute_or_load(wav_file, text)
+        item = self.samples[index]
+        ids = self.compute_or_load(item["audio_file"], item["text"])
         ph_hat = self.tokenizer.ids_to_text(ids)
-        return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
+        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
 
     def __len__(self):
         return len(self.samples)
@@ -667,12 +668,12 @@ class F0Dataset:
             self.load_stats(cache_path)
 
     def __getitem__(self, idx):
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        f0 = self.compute_or_load(wav_file)
+        item = self.samples[idx]
+        f0 = self.compute_or_load(item["audio_file"])
         if self.normalize_f0:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             f0 = self.normalize(f0)
-        return {"audio_file": wav_file, "f0": f0}
+        return {"audio_file": item["audio_file"], "f0": f0}
 
     def __len__(self):
         return len(self.samples)
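For context, the hunks above move TTSDataset, PhonemeDataset, and F0Dataset from positional samples unpacked with _parse_sample to dict samples keyed by "text" and "audio_file", and they keep the computed "audio_length"/"text_length" on the same dicts instead of slicing them off before assigning self.samples. The following is a minimal standalone sketch of that layout, not part of the patch: the texts, file names, and byte sizes are invented for illustration, and the dummy files only stand in for real 16-bit PCM audio.

import os
import tempfile

import numpy as np

# Build a couple of dict-based samples like the ones this patch expects.
# The paths and contents are hypothetical; in the real dataset the dicts
# come from the configured formatter.
tmp_dir = tempfile.mkdtemp()
samples = []
for name, text, n_bytes in [("a.wav", "hello world", 32000), ("b.wav", "a longer test sentence", 96000)]:
    path = os.path.join(tmp_dir, name)
    with open(path, "wb") as f:
        f.write(b"\x00" * n_bytes)  # stand-in for 16-bit PCM audio data
    samples.append({"text": text, "audio_file": path})

# Mirrors _compute_lengths: annotate each dict in place with an audio length
# estimated from the file size (assuming 16-bit audio) and the text length.
for item in samples:
    item["audio_length"] = os.path.getsize(item["audio_file"]) / 16 * 8
    item["text_length"] = len(item["text"])

# Mirrors sort_by_length: order samples by audio length in ascending order,
# keeping the dicts (and their new keys) intact for later filtering/bucketing.
idxs = np.argsort([s["audio_length"] for s in samples])
samples = [samples[i] for i in idxs]
print([s["audio_length"] for s in samples])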