mirror of https://github.com/coqui-ai/TTS.git
Fix colliding dataset cache file names (#1994)
* Fix colliding dataset cache file names * Remove unused code
This commit is contained in:
parent
3faccbda97
commit
d6ad9a05b4
|
@ -1,3 +1,4 @@
|
|||
import base64
|
||||
import collections
|
||||
import os
|
||||
import random
|
||||
|
@ -34,6 +35,12 @@ def noise_augment_audio(wav):
|
|||
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
|
||||
|
||||
|
||||
def string2filename(string):
    """Encode ``string`` as a filesystem-safe, reversible file name.

    The text is UTF-8 encoded and then URL-safe Base64 encoded, so the
    result never contains path separators; the alphabet uses ``-`` and
    ``_`` in place of ``+`` and ``/``, and ``=`` padding may appear at the
    end. The original string can be recovered with
    ``base64.urlsafe_b64decode(filename).decode("utf-8")``.

    Args:
        string (str): Arbitrary text, e.g. a unique sample name.

    Returns:
        str: The Base64 (URL-safe alphabet) representation of ``string``.
    """
    # generate a safe and reversible filename based on a string
    encoded = base64.urlsafe_b64encode(string.encode("utf-8"))
    # Base64 output is pure ASCII, so a strict decode cannot fail and no
    # lenient error handler is needed.
    return encoded.decode("ascii")
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -201,7 +208,7 @@ class TTSDataset(Dataset):
|
|||
def get_f0(self, idx):
|
||||
out_dict = self.f0_dataset[idx]
|
||||
item = self.samples[idx]
|
||||
assert item["audio_file"] == out_dict["audio_file"]
|
||||
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
|
||||
return out_dict
|
||||
|
||||
@staticmethod
|
||||
|
@ -561,19 +568,18 @@ class PhonemeDataset(Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
item = self.samples[index]
|
||||
ids = self.compute_or_load(item["audio_file"], item["text"])
|
||||
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
|
||||
ph_hat = self.tokenizer.ids_to_text(ids)
|
||||
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def compute_or_load(self, wav_file, text):
|
||||
def compute_or_load(self, file_name, text):
|
||||
"""Compute phonemes for the given text.
|
||||
|
||||
If the phonemes are already cached, load them from cache.
|
||||
"""
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
file_ext = "_phoneme.npy"
|
||||
cache_path = os.path.join(self.cache_path, file_name + file_ext)
|
||||
try:
|
||||
|
@ -670,11 +676,11 @@ class F0Dataset:
|
|||
|
||||
def __getitem__(self, idx):
|
||||
item = self.samples[idx]
|
||||
f0 = self.compute_or_load(item["audio_file"])
|
||||
f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
|
||||
if self.normalize_f0:
|
||||
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
|
||||
f0 = self.normalize(f0)
|
||||
return {"audio_file": item["audio_file"], "f0": f0}
|
||||
return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
@ -706,8 +712,7 @@ class F0Dataset:
|
|||
return self.pad_id
|
||||
|
||||
@staticmethod
|
||||
def create_pitch_file_path(wav_file, cache_path):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
def create_pitch_file_path(file_name, cache_path):
|
||||
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
|
||||
return pitch_file
|
||||
|
||||
|
@ -745,11 +750,11 @@ class F0Dataset:
|
|||
pitch[zero_idxs] = 0.0
|
||||
return pitch
|
||||
|
||||
def compute_or_load(self, wav_file):
|
||||
def compute_or_load(self, wav_file, audio_unique_name):
|
||||
"""
|
||||
compute pitch and return a numpy array of pitch values
|
||||
"""
|
||||
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
|
||||
pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
|
||||
if not os.path.exists(pitch_file):
|
||||
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
|
||||
else:
|
||||
|
@ -757,14 +762,14 @@ class F0Dataset:
|
|||
return pitch.astype(np.float32)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
audio_file = [item["audio_file"] for item in batch]
|
||||
audio_unique_name = [item["audio_unique_name"] for item in batch]
|
||||
f0s = [item["f0"] for item in batch]
|
||||
f0_lens = [len(item["f0"]) for item in batch]
|
||||
f0_lens_max = max(f0_lens)
|
||||
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
|
||||
for i, f0_len in enumerate(f0_lens):
|
||||
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
|
||||
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
|
||||
return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}
|
||||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue