Fix colliding dataset cache file names (#1994)

* Fix colliding dataset cache file names

* Remove unused code
This commit is contained in:
Edresson Casanova 2022-09-21 07:54:07 -03:00 committed by GitHub
parent 3faccbda97
commit d6ad9a05b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 17 additions and 12 deletions

View File

@ -1,3 +1,4 @@
import base64
import collections import collections
import os import os
import random import random
@ -34,6 +35,12 @@ def noise_augment_audio(wav):
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
def string2filename(string):
# generate a safe and reversible filename based on a string
filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
return filename
class TTSDataset(Dataset): class TTSDataset(Dataset):
def __init__( def __init__(
self, self,
@ -201,7 +208,7 @@ class TTSDataset(Dataset):
def get_f0(self, idx): def get_f0(self, idx):
out_dict = self.f0_dataset[idx] out_dict = self.f0_dataset[idx]
item = self.samples[idx] item = self.samples[idx]
assert item["audio_file"] == out_dict["audio_file"] assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict return out_dict
@staticmethod @staticmethod
@ -561,19 +568,18 @@ class PhonemeDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
item = self.samples[index] item = self.samples[index]
ids = self.compute_or_load(item["audio_file"], item["text"]) ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
ph_hat = self.tokenizer.ids_to_text(ids) ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def compute_or_load(self, wav_file, text): def compute_or_load(self, file_name, text):
"""Compute phonemes for the given text. """Compute phonemes for the given text.
If the phonemes are already cached, load them from cache. If the phonemes are already cached, load them from cache.
""" """
file_name = os.path.splitext(os.path.basename(wav_file))[0]
file_ext = "_phoneme.npy" file_ext = "_phoneme.npy"
cache_path = os.path.join(self.cache_path, file_name + file_ext) cache_path = os.path.join(self.cache_path, file_name + file_ext)
try: try:
@ -670,11 +676,11 @@ class F0Dataset:
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.samples[idx] item = self.samples[idx]
f0 = self.compute_or_load(item["audio_file"]) f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
if self.normalize_f0: if self.normalize_f0:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
f0 = self.normalize(f0) f0 = self.normalize(f0)
return {"audio_file": item["audio_file"], "f0": f0} return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
@ -706,8 +712,7 @@ class F0Dataset:
return self.pad_id return self.pad_id
@staticmethod @staticmethod
def create_pitch_file_path(wav_file, cache_path): def create_pitch_file_path(file_name, cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
return pitch_file return pitch_file
@ -745,11 +750,11 @@ class F0Dataset:
pitch[zero_idxs] = 0.0 pitch[zero_idxs] = 0.0
return pitch return pitch
def compute_or_load(self, wav_file): def compute_or_load(self, wav_file, audio_unique_name):
""" """
compute pitch and return a numpy array of pitch values compute pitch and return a numpy array of pitch values
""" """
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
if not os.path.exists(pitch_file): if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
else: else:
@ -757,14 +762,14 @@ class F0Dataset:
return pitch.astype(np.float32) return pitch.astype(np.float32)
def collate_fn(self, batch): def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch] audio_unique_name = [item["audio_unique_name"] for item in batch]
f0s = [item["f0"] for item in batch] f0s = [item["f0"] for item in batch]
f0_lens = [len(item["f0"]) for item in batch] f0_lens = [len(item["f0"]) for item in batch]
f0_lens_max = max(f0_lens) f0_lens_max = max(f0_lens)
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
for i, f0_len in enumerate(f0_lens): for i, f0_len in enumerate(f0_lens):
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens} return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}
def print_logs(self, level: int = 0) -> None: def print_logs(self, level: int = 0) -> None:
indent = "\t" * level indent = "\t" * level