Fix colliding dataset cache file names (#1994)

* Fix colliding dataset cache file names

* Remove unused code
Edresson Casanova 2022-09-21 07:54:07 -03:00 committed by GitHub
parent 3faccbda97
commit d6ad9a05b4
1 changed file with 17 additions and 12 deletions

TTS/tts/datasets/dataset.py

@@ -1,3 +1,4 @@
+import base64
 import collections
 import os
 import random
@@ -34,6 +35,12 @@ def noise_augment_audio(wav):
     return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
 
 
+def string2filename(string):
+    # generate a safe and reversible filename based on a string
+    filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
+    return filename
+
+
 class TTSDataset(Dataset):
     def __init__(
         self,
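
The new `string2filename` helper is the heart of the fix: it derives the cache key from the full dataset-unique audio name instead of the wav basename, so two files that share a basename no longer overwrite each other's cache entries. A minimal sketch of the old collision and the new behavior (the example paths are hypothetical):

```python
import base64
import os

def string2filename(string):
    # URL-safe base64 contains no "/" characters, so the result is a valid
    # single path component; the mapping is one-to-one and reversible.
    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")

# Two different files that share a basename (hypothetical paths):
a = "dataset_a/wavs/0001.wav"
b = "dataset_b/wavs/0001.wav"

# Old scheme: both collapse to the same cache key.
old = lambda p: os.path.splitext(os.path.basename(p))[0]
assert old(a) == old(b) == "0001"

# New scheme: distinct keys, and the original name is recoverable.
assert string2filename(a) != string2filename(b)
assert base64.urlsafe_b64decode(string2filename(a)).decode("utf-8") == a
```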
@@ -201,7 +208,7 @@ class TTSDataset(Dataset):
     def get_f0(self, idx):
         out_dict = self.f0_dataset[idx]
         item = self.samples[idx]
-        assert item["audio_file"] == out_dict["audio_file"]
+        assert item["audio_unique_name"] == out_dict["audio_unique_name"]
         return out_dict
 
     @staticmethod
@@ -561,19 +568,18 @@ class PhonemeDataset(Dataset):
     def __getitem__(self, index):
         item = self.samples[index]
-        ids = self.compute_or_load(item["audio_file"], item["text"])
+        ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
         ph_hat = self.tokenizer.ids_to_text(ids)
         return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
 
     def __len__(self):
         return len(self.samples)
 
-    def compute_or_load(self, wav_file, text):
+    def compute_or_load(self, file_name, text):
         """Compute phonemes for the given text.
 
         If the phonemes are already cached, load them from cache.
         """
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
         file_ext = "_phoneme.npy"
         cache_path = os.path.join(self.cache_path, file_name + file_ext)
         try:
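
With the rename, `PhonemeDataset.compute_or_load` no longer strips the path down to a basename; the caller passes `string2filename(item["audio_unique_name"])` and the method only appends the extension. A sketch of the resulting cache paths (the cache directory and sample names here are made up for illustration):

```python
import base64
import os

cache_path = "phoneme_cache"  # hypothetical cache directory

def string2filename(string):
    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")

# Before the fix, "a/0001.wav" and "b/0001.wav" both cached to
# "phoneme_cache/0001_phoneme.npy"; now each sample gets its own file.
for unique_name in ["a/0001.wav", "b/0001.wav"]:
    file_name = string2filename(unique_name)
    print(os.path.join(cache_path, file_name + "_phoneme.npy"))
```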
@@ -670,11 +676,11 @@ class F0Dataset:
     def __getitem__(self, idx):
         item = self.samples[idx]
-        f0 = self.compute_or_load(item["audio_file"])
+        f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_f0:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             f0 = self.normalize(f0)
-        return {"audio_file": item["audio_file"], "f0": f0}
+        return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
 
     def __len__(self):
         return len(self.samples)
@@ -706,8 +712,7 @@
         return self.pad_id
 
     @staticmethod
-    def create_pitch_file_path(wav_file, cache_path):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+    def create_pitch_file_path(file_name, cache_path):
         pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
         return pitch_file
@@ -745,11 +750,11 @@
             pitch[zero_idxs] = 0.0
         return pitch
 
-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute pitch and return a numpy array of pitch values
         """
-        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
+        pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(pitch_file):
             pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
         else:
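
`F0Dataset.compute_or_load` now takes both names: the unique name keys the cache file, while the real wav path is still what gets read on a cache miss. A simplified, self-contained sketch of that flow (the `compute_pitch` callable stands in for the class's `_compute_and_save_pitch`):

```python
import os
import numpy as np

def compute_or_load(wav_file, audio_unique_name, cache_path, compute_pitch):
    # The cache key comes from the collision-free unique name; the audio
    # itself is still loaded from the actual wav path when needed.
    pitch_file = os.path.join(cache_path, audio_unique_name + "_pitch.npy")
    if not os.path.exists(pitch_file):
        pitch = compute_pitch(wav_file)  # expensive F0 extraction
        np.save(pitch_file, pitch)
    else:
        pitch = np.load(pitch_file)
    return pitch.astype(np.float32)
```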
@@ -757,14 +762,14 @@
         return pitch.astype(np.float32)
 
     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         f0s = [item["f0"] for item in batch]
         f0_lens = [len(item["f0"]) for item in batch]
         f0_lens_max = max(f0_lens)
         f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
         for i, f0_len in enumerate(f0_lens):
             f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
-        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
+        return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
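
The padding logic in `collate_fn` is unchanged by this commit; only the dict key is renamed. For reference, a standalone sketch of how it right-pads each pitch contour to the batch maximum (pad id 0 and toy values assumed):

```python
import torch

f0s = [[5, 6, 7], [8, 9]]  # per-sample pitch values (toy data)
f0_lens = [len(f0) for f0 in f0s]
pad_id = 0  # assumed pad value

f0s_torch = torch.LongTensor(len(f0s), max(f0_lens)).fill_(pad_id)
for i, f0_len in enumerate(f0_lens):
    f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])

print(f0s_torch)  # tensor([[5, 6, 7], [8, 9, 0]])
```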