mirror of https://github.com/coqui-ai/TTS.git
Fix TTSDataset
parent 690de1ab06
commit 6a9f8074f0
@@ -200,8 +200,8 @@ class TTSDataset(Dataset):

     def get_f0(self, idx):
         out_dict = self.f0_dataset[idx]
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        assert wav_file == out_dict["audio_file"]
+        item = self.samples[idx]
+        assert item["audio_file"] == out_dict["audio_file"]
         return out_dict

     @staticmethod
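Note: every hunk in this commit follows the same pattern — samples used to be positional lists unpacked through _parse_sample, and are now dicts addressed by key. A minimal sketch of the two shapes (the field values, and any field beyond "text"/"audio_file", are made up for illustration):

    # old: positional list; the field order was known only to _parse_sample
    sample = ["hello world", "clips/0001.wav"]
    text, wav_file, *_ = _parse_sample(sample)

    # new: self-describing dict, as assumed by the added lines above
    sample = {"text": "hello world", "audio_file": "clips/0001.wav"}
    text, wav_file = sample["text"], sample["audio_file"]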
@@ -263,10 +263,11 @@ class TTSDataset(Dataset):
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            text, wav_file, *_ = _parse_sample(item)
-            audio_length = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
-            text_lenght = len(text)
-            new_samples += [item + [audio_length, text_lenght]]
+            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
+            text_lenght = len(item["text"])
+            item["audio_length"] = audio_length
+            item["text_length"] = text_lenght
+            new_samples += [item]
         return new_samples

     @staticmethod
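Note: os.path.getsize returns the file size in bytes, so size / 16 * 8 is just size / 2 — the number of 16-bit samples, ignoring the WAV header bytes. A quick sanity check with made-up numbers:

    n_bytes = 44100                   # ~1 s of 16-bit mono PCM at 22050 Hz
    audio_length = n_bytes / 16 * 8   # -> 22050.0 samples, i.e. ~1 s at 22050 Hz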
@@ -284,7 +285,7 @@ class TTSDataset(Dataset):

     @staticmethod
     def sort_by_length(samples: List[List]):
-        audio_lengths = [s[-2] for s in samples]
+        audio_lengths = [s["audio_length"] for s in samples]
         idxs = np.argsort(audio_lengths)  # ascending order
         return idxs

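Note: the List[List] annotation is now stale (the samples are dicts), though this hunk leaves it untouched. np.argsort returns the indices that would sort the array rather than the sorted values, which is what lets the caller reorder the samples themselves:

    import numpy as np

    audio_lengths = [22050, 8000, 16000]
    idxs = np.argsort(audio_lengths)             # array([1, 2, 0]): shortest clip first
    samples_sorted = [samples[i] for i in idxs]  # hypothetical caller-side reorder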
@@ -313,8 +314,8 @@ class TTSDataset(Dataset):
         samples = self._compute_lengths(self.samples)

         # sort items based on the sequence length in ascending order
-        text_lengths = [i[-1] for i in samples]
-        audio_lengths = [i[-2] for i in samples]
+        text_lengths = [i["text_length"] for i in samples]
+        audio_lengths = [i["audio_length"] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
         audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
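Note: filter_by_length's body is outside this diff; from the call sites it returns (ignore_idx, keep_idx) index lists, and the set intersection keeps only samples whose text and audio lengths both fall inside their bounds. A sketch under those assumptions — the function below is a hypothetical re-implementation, not the repo's:

    def filter_by_length(lengths, min_len, max_len):
        ignore_idx, keep_idx = [], []
        for idx, length in enumerate(lengths):
            (keep_idx if min_len <= length <= max_len else ignore_idx).append(idx)
        return ignore_idx, keep_idx

    _, text_keep_idx = filter_by_length([4, 120, 30], 5, 100)     # keep -> [2]
    _, audio_keep_idx = filter_by_length([0.4, 3.0, 9.9], 1, 10)  # keep -> [1, 2]
    keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))     # -> [2]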
@@ -341,9 +342,9 @@ class TTSDataset(Dataset):
         samples = self.create_buckets(samples, self.batch_group_size)

         # update items to the new sorted items
-        audio_lengths = [s[-2] for s in samples]
-        text_lengths = [s[-1] for s in samples]
-        self.samples = [s[:-2] for s in samples]
+        audio_lengths = [s["audio_length"] for s in samples]
+        text_lengths = [s["text_length"] for s in samples]
+        self.samples = samples

         if self.verbose:
             print(" | > Preprocessing samples")
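Note: this hunk also drops a bookkeeping step. The old list-based samples carried the two lengths as trailing positional fields that had to be stripped (s[:-2]) before storing; as dicts the lengths live under named keys, so the sorted samples are stored as-is and "audio_length"/"text_length" remain available for later lookups such as sort_by_length above.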
@@ -558,10 +559,10 @@ class PhonemeDataset(Dataset):
             self.precompute(precompute_num_workers)

     def __getitem__(self, index):
-        text, wav_file, *_ = _parse_sample(self.samples[index])
-        ids = self.compute_or_load(wav_file, text)
+        item = self.samples[index]
+        ids = self.compute_or_load(item["audio_file"], item["text"])
         ph_hat = self.tokenizer.ids_to_text(ids)
-        return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
+        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

     def __len__(self):
         return len(self.samples)
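Note: compute_or_load now takes the audio path and text from the dict, and ph_hat is the tokenizer's round-trip of the token ids back into a phoneme string (useful for spotting tokenization drift). The returned item looks roughly like this — all values below are invented for illustration:

    item = phoneme_dataset[0]   # phoneme_dataset: a PhonemeDataset instance
    # {"text": "hello world",
    #  "ph_hat": "həloʊ wɝld",
    #  "token_ids": [12, 7, 31, 4, 22, ...],
    #  "token_ids_len": 11}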
@@ -667,12 +668,12 @@ class F0Dataset:
             self.load_stats(cache_path)

     def __getitem__(self, idx):
-        _, wav_file, *_ = _parse_sample(self.samples[idx])
-        f0 = self.compute_or_load(wav_file)
+        item = self.samples[idx]
+        f0 = self.compute_or_load(item["audio_file"])
         if self.normalize_f0:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             f0 = self.normalize(f0)
-        return {"audio_file": wav_file, "f0": f0}
+        return {"audio_file": item["audio_file"], "f0": f0}

     def __len__(self):
         return len(self.samples)
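Note: normalize() itself is outside this diff; the assert guards it because mean and std come from precomputed pitch stats (load_stats above). A sketch of a typical z-score pitch normalization that keeps unvoiced (f0 == 0) frames at zero — an assumption about, not a copy of, the repo's implementation:

    import numpy as np

    def normalize(f0: np.ndarray, mean: float, std: float) -> np.ndarray:
        zero_idxs = np.where(f0 == 0.0)[0]  # unvoiced frames carry no pitch
        f0 = (f0 - mean) / std              # z-score on voiced values
        f0[zero_idxs] = 0.0                 # keep silence at exactly zero
        return f0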