Fix TTSDataset

Eren Gölge, 2022-03-01 07:57:48 +01:00
commit 6a9f8074f0 (parent 690de1ab06)
1 changed file with 19 additions and 18 deletions
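In short, the commit moves TTSDataset, PhonemeDataset, and F0Dataset from positional, list-style samples unpacked with `_parse_sample` to dict-style samples addressed by key. A rough sketch of the resulting sample layout, inferred only from the keys used in the hunks below (the values are invented):

# One entry of self.samples after this commit; only the keys that appear
# in the diff are assumed, the values are illustrative.
sample = {
    "text": "hello world",
    "audio_file": "/data/wavs/utt_0001.wav",
}
# _compute_lengths() later adds two keys in place:
#   sample["audio_length"]  # approximate 16-bit sample count of the wav file
#   sample["text_length"]   # len(sample["text"])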

@@ -200,8 +200,8 @@ class TTSDataset(Dataset):
     def get_f0(self, idx):
         out_dict = self.f0_dataset[idx]
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        assert wav_file == out_dict["audio_file"]
+        item = self.samples[idx]
+        assert item["audio_file"] == out_dict["audio_file"]
         return out_dict

     @staticmethod
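For context, `self.f0_dataset[idx]` returns a dict with "audio_file" and "f0" keys (see the F0Dataset hunk at the bottom of this diff), so the assert now compares dict fields directly. A minimal illustration with invented values:

import numpy as np

# Hypothetical values; only the keys match the diff.
out_dict = {"audio_file": "/data/wavs/utt_0001.wav", "f0": np.zeros(250, dtype=np.float32)}
item = {"text": "hello world", "audio_file": "/data/wavs/utt_0001.wav"}
assert item["audio_file"] == out_dict["audio_file"]  # same underlying sample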
@@ -263,10 +263,11 @@ class TTSDataset(Dataset):
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            text, wav_file, *_ = _parse_sample(item)
-            audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
-            text_lenght = len(text)
-            new_samples += [item + [audio_length, text_lenght]]
+            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio
+            text_lenght = len(item["text"])
+            item["audio_length"] = audio_length
+            item["text_length"] = text_lenght
+            new_samples += [item]
         return new_samples

     @staticmethod
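`os.path.getsize` returns bytes, so dividing by 16 and multiplying by 8 is simply dividing by 2, i.e. the number of 16-bit samples in the file (the WAV header is ignored, which this heuristic tolerates). A quick sanity check on the arithmetic, assuming a hypothetical 10-second, 22.05 kHz, 16-bit mono file:

import os

def approx_num_samples(wav_path: str) -> float:
    # Same arithmetic as the diff: bytes / 16 * 8 == bytes / 2,
    # i.e. the 16-bit sample count, give or take the WAV header.
    return os.path.getsize(wav_path) / 16 * 8

# 10 s * 22050 Hz * 2 bytes/sample is about 441_000 bytes on disk,
# so approx_num_samples(...) is about 220_500 "audio_length" units.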
@ -284,7 +285,7 @@ class TTSDataset(Dataset):
@staticmethod
def sort_by_length(samples: List[List]):
audio_lengths = [s[-2] for s in samples]
audio_lengths = [s["audio_length"] for s in samples]
idxs = np.argsort(audio_lengths) # ascending order
return idxs
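Note that the `List[List]` annotation is now stale (samples are dicts), but the behaviour is unchanged: `np.argsort` over the per-sample audio lengths. A small usage sketch with invented samples:

import numpy as np

# Made-up samples; only the "audio_length" key is taken from the diff.
samples = [
    {"audio_file": "a.wav", "audio_length": 120_000},
    {"audio_file": "b.wav", "audio_length": 40_000},
    {"audio_file": "c.wav", "audio_length": 80_000},
]
idxs = np.argsort([s["audio_length"] for s in samples])  # ascending order
print([samples[i]["audio_file"] for i in idxs])  # ['b.wav', 'c.wav', 'a.wav']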
@@ -313,8 +314,8 @@ class TTSDataset(Dataset):
         samples = self._compute_lengths(self.samples)
         # sort items based on the sequence length in ascending order
-        text_lengths = [i[-1] for i in samples]
-        audio_lengths = [i[-2] for i in samples]
+        text_lengths = [i["text_length"] for i in samples]
+        audio_lengths = [i["audio_length"] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
         audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
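`filter_by_length` itself is not part of this diff; the sketch below is a hypothetical stand-in that only mirrors its visible contract, returning `(ignore_idx, keep_idx)` for values outside and inside `[min_len, max_len]`, followed by the set intersection used above:

def filter_by_length(lengths, min_len, max_len):
    # Hypothetical re-implementation for illustration; only the return
    # shape (ignore_idx, keep_idx) is taken from the call site above.
    ignore_idx, keep_idx = [], []
    for idx, length in enumerate(lengths):
        (keep_idx if min_len <= length <= max_len else ignore_idx).append(idx)
    return ignore_idx, keep_idx

text_ignore, text_keep = filter_by_length([5, 400, 32], min_len=1, max_len=190)
audio_ignore, audio_keep = filter_by_length([8_000, 90_000, 2_000_000], 1_000, 500_000)
keep_idx = list(set(audio_keep) & set(text_keep))  # [0]: only sample 0 passes both filters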
@@ -341,9 +342,9 @@ class TTSDataset(Dataset):
             samples = self.create_buckets(samples, self.batch_group_size)

         # update items to the new sorted items
-        audio_lengths = [s[-2] for s in samples]
-        text_lengths = [s[-1] for s in samples]
-        self.samples = [s[:-2] for s in samples]
+        audio_lengths = [s["audio_length"] for s in samples]
+        text_lengths = [s["text_length"] for s in samples]
+        self.samples = samples
         if self.verbose:
             print(" | > Preprocessing samples")
@@ -558,10 +559,10 @@ class PhonemeDataset(Dataset):
             self.precompute(precompute_num_workers)

     def __getitem__(self, index):
-        text, wav_file, *_ = _parse_sample(self.samples[index])
-        ids = self.compute_or_load(wav_file, text)
+        item = self.samples[index]
+        ids = self.compute_or_load(item["audio_file"], item["text"])
         ph_hat = self.tokenizer.ids_to_text(ids)
-        return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
+        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

     def __len__(self):
         return len(self.samples)
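PhonemeDataset follows the same pattern, reading "audio_file" and "text" from the sample dict. An illustrative item as returned by `__getitem__` after the change, with every value invented and only the keys taken from the return statement above:

{
    "text": "hello world",
    "ph_hat": "h @ l oU w 3 l d",  # tokenizer.ids_to_text() round-trip of the ids
    "token_ids": [12, 7, 33, 4, 28, 19, 7, 31, 4],
    "token_ids_len": 9,
}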
@@ -667,12 +668,12 @@ class F0Dataset:
             self.load_stats(cache_path)

     def __getitem__(self, idx):
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        f0 = self.compute_or_load(wav_file)
+        item = self.samples[idx]
+        f0 = self.compute_or_load(item["audio_file"])
         if self.normalize_f0:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             f0 = self.normalize(f0)
-        return {"audio_file": wav_file, "f0": f0}
+        return {"audio_file": item["audio_file"], "f0": f0}

     def __len__(self):
         return len(self.samples)
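`normalize` itself is not shown in this diff; below is a plausible sketch of z-scoring an f0 track with the cached mean/std while leaving unvoiced frames (f0 == 0) untouched, offered as an assumption rather than the project's actual implementation:

import numpy as np

def normalize_f0(f0: np.ndarray, mean: float, std: float) -> np.ndarray:
    # Hypothetical sketch: z-score voiced frames, keep unvoiced frames at 0.
    voiced = f0 != 0.0
    out = np.zeros_like(f0)
    out[voiced] = (f0[voiced] - mean) / std
    return out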