chore(dataset): address lint issues

Enno Hermann 2024-07-31 15:40:46 +02:00
parent 8c460d0cd0
commit 9c604c1de0
1 changed file with 109 additions and 92 deletions

@@ -3,9 +3,10 @@ import collections
 import logging
 import os
 import random
-from typing import Any, Dict, List, Union
+from typing import Any, Optional, Union

 import numpy as np
+import numpy.typing as npt
 import torch
 import torchaudio
 import tqdm
@@ -32,18 +33,18 @@ def _parse_sample(item):
     elif len(item) == 3:
         text, wav_file, speaker_name = item
     else:
-        raise ValueError(" [!] Dataset cannot parse the sample.")
+        msg = "Dataset cannot parse the sample."
+        raise ValueError(msg)
     return text, wav_file, speaker_name, language_name, attn_file


-def noise_augment_audio(wav):
+def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray:
     return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)


-def string2filename(string):
+def string2filename(string: str) -> str:
     # generate a safe and reversible filename based on a string
-    filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
-    return filename
+    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")


 def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
@@ -52,9 +53,8 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
     audiopath = str(audiopath)
     extension = audiopath.rpartition(".")[-1].lower()
     if extension not in {"mp3", "wav", "flac"}:
-        raise RuntimeError(
-            f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
-        )
+        msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
+        raise RuntimeError(msg)

     try:
         return torchaudio.info(audiopath).num_frames
@@ -69,31 +69,32 @@ class TTSDataset(Dataset):
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
         ap: AudioProcessor = None,
-        samples: List[Dict] = None,
+        samples: Optional[list[dict]] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         compute_energy: bool = False,
-        f0_cache_path: str = None,
-        energy_cache_path: str = None,
+        f0_cache_path: Optional[str] = None,
+        energy_cache_path: Optional[str] = None,
         return_wav: bool = False,
         batch_group_size: int = 0,
         min_text_len: int = 0,
         max_text_len: int = float("inf"),
         min_audio_len: int = 0,
         max_audio_len: int = float("inf"),
-        phoneme_cache_path: str = None,
+        phoneme_cache_path: Optional[str] = None,
         precompute_num_workers: int = 0,
-        speaker_id_mapping: Dict = None,
-        d_vector_mapping: Dict = None,
-        language_id_mapping: Dict = None,
+        speaker_id_mapping: Optional[dict] = None,
+        d_vector_mapping: Optional[dict] = None,
+        language_id_mapping: Optional[dict] = None,
         use_noise_augment: bool = False,
         start_by_longest: bool = False,
-    ):
+    ) -> None:
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

         If you need something different, you can subclass and override.

         Args:
+        ----
             outputs_per_step (int): Number of time frames predicted per step.

             compute_linear_spec (bool): compute linear spectrogram if True.
@@ -145,6 +146,7 @@ class TTSDataset(Dataset):
             use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

             start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
+
         """
         super().__init__()
         self.batch_group_size = batch_group_size
@@ -174,28 +176,37 @@ class TTSDataset(Dataset):
         if self.tokenizer.use_phonemes:
             self.phoneme_dataset = PhonemeDataset(
-                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.tokenizer,
+                phoneme_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )

         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=f0_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
         if compute_energy:
             self.energy_dataset = EnergyDataset(
-                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=energy_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )

         self.print_logs()

     @property
-    def lengths(self):
+    def lengths(self) -> list[int]:
         lens = []
         for item in self.samples:
             _, wav_file, *_ = _parse_sample(item)
             try:
                 audio_len = get_audio_size(wav_file)
             except RuntimeError:
-                logger.warn(f"Failed to compute length for {item['audio_file']}")
+                logger.warning(f"Failed to compute length for {item['audio_file']}")
                 audio_len = 0
             lens.append(audio_len)
         return lens
@@ -205,7 +216,7 @@ class TTSDataset(Dataset):
         return self._samples

     @samples.setter
-    def samples(self, new_samples):
+    def samples(self, new_samples) -> None:
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
@@ -214,7 +225,7 @@ class TTSDataset(Dataset):
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)

     def __getitem__(self, idx):
@@ -261,7 +272,7 @@ class TTSDataset(Dataset):
         token_ids = self.tokenizer.text_to_ids(text)
         return np.array(token_ids, dtype=np.int32)

-    def load_data(self, idx):
+    def load_data(self, idx) -> dict[str, Any]:
         item = self.samples[idx]

         raw_text = item["text"]
@@ -295,7 +306,7 @@ class TTSDataset(Dataset):
         if self.compute_energy:
             energy = self.get_energy(idx)["energy"]

-        sample = {
+        return {
             "raw_text": raw_text,
             "token_ids": token_ids,
             "wav": wav,
@@ -308,7 +319,6 @@ class TTSDataset(Dataset):
             "wav_file_name": os.path.basename(item["audio_file"]),
             "audio_unique_name": item["audio_unique_name"],
         }
-        return sample

     @staticmethod
     def _compute_lengths(samples):
@@ -317,7 +327,7 @@ class TTSDataset(Dataset):
             try:
                 audio_length = get_audio_size(item["audio_file"])
             except RuntimeError:
-                logger.warn(f"Failed to compute length, skipping {item['audio_file']}")
+                logger.warning(f"Failed to compute length, skipping {item['audio_file']}")
                 continue
             text_lenght = len(item["text"])
             item["audio_length"] = audio_length
@@ -326,7 +336,7 @@ class TTSDataset(Dataset):
         return new_samples

     @staticmethod
-    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
+    def filter_by_length(lengths: list[int], min_len: int, max_len: int):
         idxs = np.argsort(lengths)  # ascending order
         ignore_idx = []
         keep_idx = []
@@ -339,10 +349,9 @@ class TTSDataset(Dataset):
         return ignore_idx, keep_idx

     @staticmethod
-    def sort_by_length(samples: List[List]):
+    def sort_by_length(samples: list[list]):
         audio_lengths = [s["audio_length"] for s in samples]
-        idxs = np.argsort(audio_lengths)  # ascending order
-        return idxs
+        return np.argsort(audio_lengths)  # ascending order

     @staticmethod
     def create_buckets(samples, batch_group_size: int):
@@ -362,7 +371,7 @@ class TTSDataset(Dataset):
             samples_new.append(samples[idx])
         return samples_new

-    def preprocess_samples(self):
+    def preprocess_samples(self) -> None:
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
         range.
         """
@@ -388,7 +397,8 @@ class TTSDataset(Dataset):
             samples = self._select_samples_by_idx(sorted_idxs, samples)

         if len(samples) == 0:
-            raise RuntimeError(" [!] No samples left")
+            msg = "No samples left."
+            raise RuntimeError(msg)

         # shuffle batch groups
         # create batches with similar length items
@@ -402,36 +412,37 @@ class TTSDataset(Dataset):
         self.samples = samples

         logger.info("Preprocessing samples")
-        logger.info("Max text length: {}".format(np.max(text_lengths)))
-        logger.info("Min text length: {}".format(np.min(text_lengths)))
-        logger.info("Avg text length: {}".format(np.mean(text_lengths)))
-        logger.info("Max audio length: {}".format(np.max(audio_lengths)))
-        logger.info("Min audio length: {}".format(np.min(audio_lengths)))
-        logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
+        logger.info(f"Max text length: {np.max(text_lengths)}")
+        logger.info(f"Min text length: {np.min(text_lengths)}")
+        logger.info(f"Avg text length: {np.mean(text_lengths)}")
+        logger.info(f"Max audio length: {np.max(audio_lengths)}")
+        logger.info(f"Min audio length: {np.min(audio_lengths)}")
+        logger.info(f"Avg audio length: {np.mean(audio_lengths)}")
         logger.info("Num. instances discarded samples: %d", len(ignore_idx))
-        logger.info("Batch group size: {}.".format(self.batch_group_size))
+        logger.info(f"Batch group size: {self.batch_group_size}.")

     @staticmethod
     def _sort_batch(batch, text_lengths):
         """Sort the batch by the input text length for RNN efficiency.

         Args:
+        ----
             batch (Dict): Batch returned by `__getitem__`.
             text_lengths (List[int]): Lengths of the input character sequences.
+
         """
         text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
         batch = [batch[idx] for idx in ids_sorted_decreasing]
         return batch, text_lengths, ids_sorted_decreasing

     def collate_fn(self, batch):
-        r"""
-        Perform preprocessing and create a final data batch:
+        """Perform preprocessing and create a final data batch.

         1. Sort batch instances by text-length
         2. Convert Audio signal to features.
         3. PAD sequences wrt r.
         4. Load to Torch.
         """
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
             token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
@@ -576,23 +587,18 @@ class TTSDataset(Dataset):
             "audio_unique_names": batch["audio_unique_name"],
         }

-        raise TypeError(
-            (
-                "batch must contain tensors, numbers, dicts or lists;\
-             found {}".format(
-                    type(batch[0])
-                )
-            )
-        )
+        msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}"
+        raise TypeError(msg)


 class PhonemeDataset(Dataset):
-    """Phoneme Dataset for converting input text to phonemes and then token IDs
+    """Phoneme Dataset for converting input text to phonemes and then token IDs.

     At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
     loading latency. If `cache_path` is already present, it skips the pre-computation.

     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
@@ -604,15 +610,16 @@ class PhonemeDataset(Dataset):
         precompute_num_workers (int):
             Number of workers used for pre-computing the phonemes. Defaults to 0.
+
     """

     def __init__(
         self,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         tokenizer: "TTSTokenizer",
         cache_path: str,
-        precompute_num_workers=0,
-    ):
+        precompute_num_workers: int = 0,
+    ) -> None:
         self.samples = samples
         self.tokenizer = tokenizer
         self.cache_path = cache_path
@@ -620,16 +627,16 @@ class PhonemeDataset(Dataset):
             os.makedirs(cache_path)
         self.precompute(precompute_num_workers)

-    def __getitem__(self, index):
+    def __getitem__(self, index) -> dict[str, Any]:
         item = self.samples[index]
         ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
         ph_hat = self.tokenizer.ids_to_text(ids)
         return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)

-    def compute_or_load(self, file_name, text, language):
+    def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]:
         """Compute phonemes for the given text.

         If the phonemes are already cached, load them from cache.
@@ -643,11 +650,11 @@ class PhonemeDataset(Dataset):
             np.save(cache_path, ids)
         return ids

-    def get_pad_id(self):
-        """Get pad token ID for sequence padding"""
+    def get_pad_id(self) -> int:
+        """Get pad token ID for sequence padding."""
         return self.tokenizer.pad_id

-    def precompute(self, num_workers=1):
+    def precompute(self, num_workers: int = 1) -> None:
         """Precompute phonemes for all samples.

         We use pytorch dataloader because we are lazy.
@@ -656,7 +663,11 @@ class PhonemeDataset(Dataset):
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
             )
             for _ in dataloder:
                 pbar.update(batch_size)
@@ -681,12 +692,13 @@ class PhonemeDataset(Dataset):

 class F0Dataset:
-    """F0 Dataset for computing F0 from wav files in CPU
+    """F0 Dataset for computing F0 from wav files in CPU.

     Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
     also computes the mean and std of F0 values if `normalize_f0` is True.

     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
@@ -702,17 +714,18 @@ class F0Dataset:
         normalize_f0 (bool):
             Whether to normalize F0 values by mean and std. Defaults to True.
+
     """

     def __init__(
         self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
         ap: "AudioProcessor",
         audio_config=None,  # pylint: disable=unused-argument
-        cache_path: str = None,
-        precompute_num_workers=0,
-        normalize_f0=True,
-    ):
+        cache_path: Optional[str] = None,
+        precompute_num_workers: int = 0,
+        normalize_f0: bool = True,
+    ) -> None:
         self.samples = samples
         self.ap = ap
         self.cache_path = cache_path
@@ -734,10 +747,10 @@ class F0Dataset:
             f0 = self.normalize(f0)
         return {"audio_unique_name": item["audio_unique_name"], "f0": f0}

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)

-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers: int = 0) -> None:
         logger.info("Pre-computing F0s...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
@@ -745,7 +758,11 @@ class F0Dataset:
             normalize_f0 = self.normalize_f0
             self.normalize_f0 = False
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
             )
             computed_data = []
             for batch in dataloder:
@@ -764,9 +781,8 @@ class F0Dataset:
         return self.pad_id

     @staticmethod
-    def create_pitch_file_path(file_name, cache_path):
-        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
-        return pitch_file
+    def create_pitch_file_path(file_name: str, cache_path: str) -> str:
+        return os.path.join(cache_path, file_name + "_pitch.npy")

     @staticmethod
     def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
@@ -782,7 +798,7 @@ class F0Dataset:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std

-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
         self.mean = stats["mean"].astype(np.float32)
@@ -803,9 +819,7 @@ class F0Dataset:
         return pitch

     def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute pitch and return a numpy array of pitch values
-        """
+        """Compute pitch and return a numpy array of pitch values."""
         pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(pitch_file):
             pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
@@ -830,12 +844,13 @@ class F0Dataset:

 class EnergyDataset:
-    """Energy Dataset for computing Energy from wav files in CPU
+    """Energy Dataset for computing Energy from wav files in CPU.

     Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It
     also computes the mean and std of Energy values if `normalize_Energy` is True.

     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
@@ -851,16 +866,17 @@ class EnergyDataset:
         normalize_Energy (bool):
             Whether to normalize Energy values by mean and std. Defaults to True.
+
     """

     def __init__(
         self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
         ap: "AudioProcessor",
-        cache_path: str = None,
+        cache_path: Optional[str] = None,
         precompute_num_workers=0,
         normalize_energy=True,
-    ):
+    ) -> None:
         self.samples = samples
         self.ap = ap
         self.cache_path = cache_path
@@ -882,10 +898,10 @@ class EnergyDataset:
             energy = self.normalize(energy)
         return {"audio_unique_name": item["audio_unique_name"], "energy": energy}

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)

-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers=0) -> None:
         logger.info("Pre-computing energys...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
@@ -893,7 +909,11 @@ class EnergyDataset:
             normalize_energy = self.normalize_energy
             self.normalize_energy = False
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
            )
             computed_data = []
             for batch in dataloder:
@@ -914,8 +934,7 @@ class EnergyDataset:
     @staticmethod
     def create_energy_file_path(wav_file, cache_path):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]
-        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
-        return energy_file
+        return os.path.join(cache_path, file_name + "_energy.npy")

     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
@@ -931,7 +950,7 @@ class EnergyDataset:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std

-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
         stats_path = os.path.join(cache_path, "energy_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
         self.mean = stats["mean"].astype(np.float32)
@@ -952,9 +971,7 @@ class EnergyDataset:
         return energy

     def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute energy and return a numpy array of energy values
-        """
+        """Compute energy and return a numpy array of energy values."""
         energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
            energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
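
Note on the recurring "msg = ...; raise ...(msg)" rewrites above: they follow ruff's flake8-errmsg rules (EM101 for string literals, EM102 for f-strings passed directly to an exception constructor). The rationale is that the traceback already prints the raising source line, so an inline message literal would appear twice. A minimal standalone sketch of the before/after shape (the function and message below are illustrative, not taken from this diff):

    # Flagged by ruff EM102: f-string passed directly to the exception.
    def parse_sample_rate_old(value: str) -> int:
        if not value.isdigit():
            raise ValueError(f"Invalid sample rate: {value}")
        return int(value)

    # Preferred: bind the message to a variable, then raise it.
    def parse_sample_rate_new(value: str) -> int:
        if not value.isdigit():
            msg = f"Invalid sample rate: {value}"
            raise ValueError(msg)
        return int(value)

The remaining changes are in the same housekeeping spirit: `logger.warn` is a deprecated alias for `logger.warning`, `Optional[...]` makes the implicit-`None` defaults explicit, and the PEP 585 builtin generics `list`/`dict` replace `typing.List`/`typing.Dict`.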