mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #66 from idiap/skip-broken-audio
Skip audio files that can't be decoded
Commit 19fce2c87c
TTS/tts/datasets/dataset.py

@@ -3,9 +3,10 @@ import collections
 import logging
 import os
 import random
-from typing import Dict, List, Union
+from typing import Any, Optional, Union
 
 import numpy as np
+import numpy.typing as npt
 import torch
 import torchaudio
 import tqdm
@@ -32,29 +33,34 @@ def _parse_sample(item):
     elif len(item) == 3:
         text, wav_file, speaker_name = item
     else:
-        raise ValueError(" [!] Dataset cannot parse the sample.")
+        msg = "Dataset cannot parse the sample."
+        raise ValueError(msg)
     return text, wav_file, speaker_name, language_name, attn_file
 
 
-def noise_augment_audio(wav):
+def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray:
     return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
 
 
-def string2filename(string):
+def string2filename(string: str) -> str:
     # generate a safe and reversible filename based on a string
-    filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
-    return filename
+    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
 
 
-def get_audio_size(audiopath) -> int:
+def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
     """Return the number of samples in the audio file."""
+    if not isinstance(audiopath, str):
+        audiopath = str(audiopath)
     extension = audiopath.rpartition(".")[-1].lower()
     if extension not in {"mp3", "wav", "flac"}:
-        raise RuntimeError(
-            f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
-        )
+        msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
+        raise RuntimeError(msg)
 
-    return torchaudio.info(audiopath).num_frames
+    try:
+        return torchaudio.info(audiopath).num_frames
+    except RuntimeError as e:
+        msg = f"Failed to decode {audiopath}"
+        raise RuntimeError(msg) from e
 
 
 class TTSDataset(Dataset):
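Worth noting what this hunk buys callers: unsupported extensions and undecodable files now both surface as `RuntimeError`, with the original `torchaudio` failure chained via `from e`. A minimal caller sketch (the file path is hypothetical; the import assumes the module path `TTS.tts.datasets.dataset`):

```python
from TTS.tts.datasets.dataset import get_audio_size

try:
    num_samples = get_audio_size("clips/utt_0001.wav")  # hypothetical path
except RuntimeError as e:
    # Unsupported format and decode failure both land here; the original
    # torchaudio error is preserved as e.__cause__ thanks to "raise ... from e".
    print(f"Skipping file: {e} (cause: {e.__cause__})")
```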
@@ -63,31 +69,32 @@ class TTSDataset(Dataset):
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
         ap: AudioProcessor = None,
-        samples: List[Dict] = None,
+        samples: Optional[list[dict]] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         compute_energy: bool = False,
-        f0_cache_path: str = None,
-        energy_cache_path: str = None,
+        f0_cache_path: Optional[str] = None,
+        energy_cache_path: Optional[str] = None,
         return_wav: bool = False,
         batch_group_size: int = 0,
         min_text_len: int = 0,
         max_text_len: int = float("inf"),
         min_audio_len: int = 0,
         max_audio_len: int = float("inf"),
-        phoneme_cache_path: str = None,
+        phoneme_cache_path: Optional[str] = None,
         precompute_num_workers: int = 0,
-        speaker_id_mapping: Dict = None,
-        d_vector_mapping: Dict = None,
-        language_id_mapping: Dict = None,
+        speaker_id_mapping: Optional[dict] = None,
+        d_vector_mapping: Optional[dict] = None,
+        language_id_mapping: Optional[dict] = None,
         use_noise_augment: bool = False,
         start_by_longest: bool = False,
-    ):
+    ) -> None:
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
 
         If you need something different, you can subclass and override.
 
         Args:
+        ----
             outputs_per_step (int): Number of time frames predicted per step.
 
             compute_linear_spec (bool): compute linear spectrogram if True.
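The signature changes above all follow one pattern: implicit `Optional` defaults (`samples: List[Dict] = None`) become explicit `Optional[...]`, and `typing.List`/`typing.Dict` give way to the built-in generics available since Python 3.9. The motivation looks like a pyupgrade/ruff pass, though the diff itself does not say so; a small before/after sketch with hypothetical names:

```python
from typing import Optional

# Before: implicit Optional and deprecated typing aliases.
# def __init__(self, samples: List[Dict] = None, f0_cache_path: str = None): ...

# After: the None default is spelled out and built-in generics are used.
def configure(samples: Optional[list[dict]] = None, f0_cache_path: Optional[str] = None) -> None:
    samples = samples if samples is not None else []
```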
@@ -139,6 +146,7 @@ class TTSDataset(Dataset):
         use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
 
         start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
+
         """
         super().__init__()
         self.batch_group_size = batch_group_size
@@ -168,25 +176,38 @@ class TTSDataset(Dataset):
 
         if self.tokenizer.use_phonemes:
             self.phoneme_dataset = PhonemeDataset(
-                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.tokenizer,
+                phoneme_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
 
         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=f0_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
         if compute_energy:
             self.energy_dataset = EnergyDataset(
-                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=energy_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
         self.print_logs()
 
     @property
-    def lengths(self):
+    def lengths(self) -> list[int]:
         lens = []
         for item in self.samples:
             _, wav_file, *_ = _parse_sample(item)
-            audio_len = get_audio_size(wav_file)
+            try:
+                audio_len = get_audio_size(wav_file)
+            except RuntimeError:
+                logger.warning(f"Failed to compute length for {item['audio_file']}")
+                audio_len = 0
             lens.append(audio_len)
         return lens
 
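In `lengths`, a broken file now degrades to length 0 rather than aborting the whole property. A self-contained rendition of that logic (assuming `get_audio_size` from this module):

```python
import logging

from TTS.tts.datasets.dataset import get_audio_size

logger = logging.getLogger(__name__)

def safe_lengths(wav_files: list[str]) -> list[int]:
    """Mirror of the new `lengths` behaviour: broken files report 0."""
    lens = []
    for wav_file in wav_files:
        try:
            lens.append(get_audio_size(wav_file))
        except RuntimeError:
            logger.warning("Failed to compute length for %s", wav_file)
            lens.append(0)  # placeholder keeps index alignment with samples
    return lens
```

The 0 sentinel preserves index alignment with `self.samples`, which matters for length-based samplers; actual removal of broken items happens in `_compute_lengths` below.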
@@ -195,7 +216,7 @@ class TTSDataset(Dataset):
         return self._samples
 
     @samples.setter
-    def samples(self, new_samples):
+    def samples(self, new_samples) -> None:
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
@@ -204,7 +225,7 @@ class TTSDataset(Dataset):
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)
 
     def __getitem__(self, idx):
@@ -251,7 +272,7 @@ class TTSDataset(Dataset):
         token_ids = self.tokenizer.text_to_ids(text)
         return np.array(token_ids, dtype=np.int32)
 
-    def load_data(self, idx):
+    def load_data(self, idx) -> dict[str, Any]:
         item = self.samples[idx]
 
         raw_text = item["text"]
@@ -285,7 +306,7 @@ class TTSDataset(Dataset):
         if self.compute_energy:
             energy = self.get_energy(idx)["energy"]
 
-        sample = {
+        return {
             "raw_text": raw_text,
             "token_ids": token_ids,
             "wav": wav,
@@ -298,13 +319,16 @@ class TTSDataset(Dataset):
             "wav_file_name": os.path.basename(item["audio_file"]),
             "audio_unique_name": item["audio_unique_name"],
         }
-        return sample
 
     @staticmethod
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            audio_length = get_audio_size(item["audio_file"])
+            try:
+                audio_length = get_audio_size(item["audio_file"])
+            except RuntimeError:
+                logger.warning(f"Failed to compute length, skipping {item['audio_file']}")
+                continue
             text_lenght = len(item["text"])
             item["audio_length"] = audio_length
             item["text_length"] = text_lenght
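This is the hunk that implements the PR title: during preprocessing, a sample whose audio cannot be decoded is dropped instead of crashing the run. A standalone sketch of the skip pattern (the sample dicts are hypothetical):

```python
from TTS.tts.datasets.dataset import get_audio_size

def compute_lengths(samples: list[dict]) -> list[dict]:
    kept = []
    for item in samples:
        try:
            item["audio_length"] = get_audio_size(item["audio_file"])
        except RuntimeError:
            continue  # undecodable audio: exclude the sample entirely
        item["text_length"] = len(item["text"])
        kept.append(item)
    return kept
```

Note the asymmetry with `lengths` above: there a broken file yields 0 to preserve indexing, here it is removed so it never reaches training.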
@@ -312,7 +336,7 @@ class TTSDataset(Dataset):
         return new_samples
 
     @staticmethod
-    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
+    def filter_by_length(lengths: list[int], min_len: int, max_len: int):
         idxs = np.argsort(lengths)  # ascending order
         ignore_idx = []
         keep_idx = []
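For context, `filter_by_length` walks indices in ascending length order via `np.argsort` and partitions them into `ignore_idx` and `keep_idx`; a toy run of that idea:

```python
import numpy as np

lengths = [120, 30, 999, 450]
idxs = np.argsort(lengths)  # array([1, 0, 3, 2]), ascending by length
keep_idx = [int(i) for i in idxs if 50 <= lengths[i] <= 500]        # [0, 3]
ignore_idx = [int(i) for i in idxs if not 50 <= lengths[i] <= 500]  # [1, 2]
```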
@@ -325,10 +349,9 @@ class TTSDataset(Dataset):
         return ignore_idx, keep_idx
 
     @staticmethod
-    def sort_by_length(samples: List[List]):
+    def sort_by_length(samples: list[list]):
         audio_lengths = [s["audio_length"] for s in samples]
-        idxs = np.argsort(audio_lengths)  # ascending order
-        return idxs
+        return np.argsort(audio_lengths)  # ascending order
 
     @staticmethod
     def create_buckets(samples, batch_group_size: int):
@@ -348,7 +371,7 @@ class TTSDataset(Dataset):
             samples_new.append(samples[idx])
         return samples_new
 
-    def preprocess_samples(self):
+    def preprocess_samples(self) -> None:
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
         range.
         """
@@ -374,7 +397,8 @@ class TTSDataset(Dataset):
         samples = self._select_samples_by_idx(sorted_idxs, samples)
 
         if len(samples) == 0:
-            raise RuntimeError(" [!] No samples left")
+            msg = "No samples left."
+            raise RuntimeError(msg)
 
         # shuffle batch groups
         # create batches with similar length items
@@ -388,36 +412,37 @@ class TTSDataset(Dataset):
         self.samples = samples
 
         logger.info("Preprocessing samples")
-        logger.info("Max text length: {}".format(np.max(text_lengths)))
-        logger.info("Min text length: {}".format(np.min(text_lengths)))
-        logger.info("Avg text length: {}".format(np.mean(text_lengths)))
-        logger.info("Max audio length: {}".format(np.max(audio_lengths)))
-        logger.info("Min audio length: {}".format(np.min(audio_lengths)))
-        logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
+        logger.info(f"Max text length: {np.max(text_lengths)}")
+        logger.info(f"Min text length: {np.min(text_lengths)}")
+        logger.info(f"Avg text length: {np.mean(text_lengths)}")
+        logger.info(f"Max audio length: {np.max(audio_lengths)}")
+        logger.info(f"Min audio length: {np.min(audio_lengths)}")
+        logger.info(f"Avg audio length: {np.mean(audio_lengths)}")
         logger.info("Num. instances discarded samples: %d", len(ignore_idx))
-        logger.info("Batch group size: {}.".format(self.batch_group_size))
+        logger.info(f"Batch group size: {self.batch_group_size}.")
 
     @staticmethod
     def _sort_batch(batch, text_lengths):
         """Sort the batch by the input text length for RNN efficiency.
 
         Args:
+        ----
             batch (Dict): Batch returned by `__getitem__`.
             text_lengths (List[int]): Lengths of the input character sequences.
+
         """
         text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
         batch = [batch[idx] for idx in ids_sorted_decreasing]
         return batch, text_lengths, ids_sorted_decreasing
 
     def collate_fn(self, batch):
-        r"""
-        Perform preprocessing and create a final data batch:
+        """Perform preprocessing and create a final data batch.
 
         1. Sort batch instances by text-length
         2. Convert Audio signal to features.
         3. PAD sequences wrt r.
         4. Load to Torch.
         """
-
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
             token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
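A quick illustration of what `_sort_batch` returns: the lengths sorted descending plus the permutation needed to reorder every other field of the batch to match:

```python
import torch

text_lengths = [5, 9, 3]
lens, order = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
print(lens.tolist())   # [9, 5, 3]
print(order.tolist())  # [1, 0, 2]

batch = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
print([batch[int(i)] for i in order])  # items reordered to match sorted lengths
```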
@@ -562,23 +587,18 @@ class TTSDataset(Dataset):
             "audio_unique_names": batch["audio_unique_name"],
         }
 
-        raise TypeError(
-            (
-                "batch must contain tensors, numbers, dicts or lists;\
-                 found {}".format(
-                    type(batch[0])
-                )
-            )
-        )
+        msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}"
+        raise TypeError(msg)
 
 
 class PhonemeDataset(Dataset):
-    """Phoneme Dataset for converting input text to phonemes and then token IDs
+    """Phoneme Dataset for converting input text to phonemes and then token IDs.
 
     At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
     loading latency. If `cache_path` is already present, it skips the pre-computation.
 
     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
 
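The `msg = ...; raise ...(msg)` rewrite here and in the earlier hunks matches ruff's EM101/EM102 style, presumably the motivation for this part of the diff: binding the message to a variable keeps the traceback line to a plain `raise TypeError(msg)` instead of a nested, line-continued `format()` expression. In isolation:

```python
def check_batch(batch: list) -> None:
    if not isinstance(batch[0], dict):
        # One readable line instead of a multi-line format() call in the raise.
        msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}"
        raise TypeError(msg)
```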
@@ -590,15 +610,16 @@ class PhonemeDataset(Dataset):
 
     precompute_num_workers (int):
         Number of workers used for pre-computing the phonemes. Defaults to 0.
+
     """
 
     def __init__(
         self,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         tokenizer: "TTSTokenizer",
         cache_path: str,
-        precompute_num_workers=0,
-    ):
+        precompute_num_workers: int = 0,
+    ) -> None:
         self.samples = samples
         self.tokenizer = tokenizer
         self.cache_path = cache_path
@@ -606,16 +627,16 @@ class PhonemeDataset(Dataset):
             os.makedirs(cache_path)
             self.precompute(precompute_num_workers)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> dict[str, Any]:
         item = self.samples[index]
         ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
         ph_hat = self.tokenizer.ids_to_text(ids)
         return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)
 
-    def compute_or_load(self, file_name, text, language):
+    def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]:
         """Compute phonemes for the given text.
 
         If the phonemes are already cached, load them from cache.
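`compute_or_load` keys the cache on `string2filename(item["audio_unique_name"])`, i.e. a URL-safe base64 encoding, so arbitrary utterance names map to valid, reversible filenames. A small demonstration (the utterance name is made up):

```python
import base64

from TTS.tts.datasets.dataset import string2filename

name = string2filename("speaker1/utt 0001.wav")  # hypothetical unique name
print(name)  # filesystem-safe token: no slashes or spaces

# Reversible: decoding recovers the original string.
print(base64.urlsafe_b64decode(name.encode()).decode("utf-8"))
```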
@@ -629,11 +650,11 @@ class PhonemeDataset(Dataset):
         np.save(cache_path, ids)
         return ids
 
-    def get_pad_id(self):
-        """Get pad token ID for sequence padding"""
+    def get_pad_id(self) -> int:
+        """Get pad token ID for sequence padding."""
         return self.tokenizer.pad_id
 
-    def precompute(self, num_workers=1):
+    def precompute(self, num_workers: int = 1) -> None:
         """Precompute phonemes for all samples.
 
         We use pytorch dataloader because we are lazy.
@@ -642,7 +663,11 @@ class PhonemeDataset(Dataset):
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
             )
             for _ in dataloder:
                 pbar.update(batch_size)
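The precompute idiom above is worth spelling out: the dataset is fed to a `DataLoader` over itself purely for its side effects, so `num_workers` processes fill the phoneme cache in parallel while the batches themselves are discarded. A toy equivalent:

```python
import torch

class CacheFiller(torch.utils.data.Dataset):
    """Stand-in dataset: __getitem__ performs the expensive, cached work."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx * idx  # imagine: compute phonemes here and np.save them

loader = torch.utils.data.DataLoader(CacheFiller(8), batch_size=2, num_workers=0)
for _ in loader:  # results discarded; only the side effects matter
    pass
```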
@@ -667,12 +692,13 @@ class PhonemeDataset(Dataset):
 
 
 class F0Dataset:
-    """F0 Dataset for computing F0 from wav files in CPU
+    """F0 Dataset for computing F0 from wav files in CPU.
 
     Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
     also computes the mean and std of F0 values if `normalize_f0` is True.
 
     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
 
@@ -688,17 +714,18 @@ class F0Dataset:
 
     normalize_f0 (bool):
         Whether to normalize F0 values by mean and std. Defaults to True.
+
     """
 
     def __init__(
         self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
         ap: "AudioProcessor",
         audio_config=None,  # pylint: disable=unused-argument
-        cache_path: str = None,
-        precompute_num_workers=0,
-        normalize_f0=True,
-    ):
+        cache_path: Optional[str] = None,
+        precompute_num_workers: int = 0,
+        normalize_f0: bool = True,
+    ) -> None:
         self.samples = samples
         self.ap = ap
         self.cache_path = cache_path
@@ -720,10 +747,10 @@ class F0Dataset:
             f0 = self.normalize(f0)
         return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)
 
-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers: int = 0) -> None:
         logger.info("Pre-computing F0s...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
@@ -731,7 +758,11 @@ class F0Dataset:
             normalize_f0 = self.normalize_f0
             self.normalize_f0 = False
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
            )
             computed_data = []
             for batch in dataloder:
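One subtlety the reformatting makes easier to see: `normalize_f0` is temporarily switched off while raw values are computed (the mean/std are not known yet), then restored afterwards. A generic save-and-restore sketch of that pattern (the try/finally is an extra safety measure, not in the original):

```python
class StatsComputer:
    def __init__(self) -> None:
        self.normalize = True
        self.mean = 0.0

    def precompute(self, values: list[float]) -> list[float]:
        normalize = self.normalize  # remember caller-visible state
        self.normalize = False      # raw values only: stats not known yet
        try:
            raw = list(values)      # stand-in for the DataLoader pass
            self.mean = sum(raw) / len(raw)
        finally:
            self.normalize = normalize  # restore the flag
        return raw
```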
@@ -750,9 +781,8 @@ class F0Dataset:
         return self.pad_id
 
     @staticmethod
-    def create_pitch_file_path(file_name, cache_path):
-        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
-        return pitch_file
+    def create_pitch_file_path(file_name: str, cache_path: str) -> str:
+        return os.path.join(cache_path, file_name + "_pitch.npy")
 
     @staticmethod
     def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
@@ -768,7 +798,7 @@ class F0Dataset:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std
 
-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
         self.mean = stats["mean"].astype(np.float32)
@@ -789,9 +819,7 @@ class F0Dataset:
         return pitch
 
     def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute pitch and return a numpy array of pitch values
-        """
+        """Compute pitch and return a numpy array of pitch values."""
         pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(pitch_file):
             pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
@@ -816,12 +844,13 @@ class F0Dataset:
 
 
 class EnergyDataset:
-    """Energy Dataset for computing Energy from wav files in CPU
+    """Energy Dataset for computing Energy from wav files in CPU.
 
     Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It
     also computes the mean and std of Energy values if `normalize_Energy` is True.
 
     Args:
+    ----
         samples (Union[List[List], List[Dict]]):
             List of samples. Each sample is a list or a dict.
 
@@ -837,16 +866,17 @@ class EnergyDataset:
 
     normalize_Energy (bool):
         Whether to normalize Energy values by mean and std. Defaults to True.
+
     """
 
     def __init__(
         self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
         ap: "AudioProcessor",
-        cache_path: str = None,
+        cache_path: Optional[str] = None,
         precompute_num_workers=0,
         normalize_energy=True,
-    ):
+    ) -> None:
         self.samples = samples
         self.ap = ap
         self.cache_path = cache_path
@@ -868,10 +898,10 @@ class EnergyDataset:
             energy = self.normalize(energy)
         return {"audio_unique_name": item["audio_unique_name"], "energy": energy}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)
 
-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers=0) -> None:
         logger.info("Pre-computing energys...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
@@ -879,7 +909,11 @@ class EnergyDataset:
             normalize_energy = self.normalize_energy
             self.normalize_energy = False
             dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
             )
             computed_data = []
             for batch in dataloder:
@@ -900,8 +934,7 @@ class EnergyDataset:
     @staticmethod
     def create_energy_file_path(wav_file, cache_path):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]
-        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
-        return energy_file
+        return os.path.join(cache_path, file_name + "_energy.npy")
 
     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
@@ -917,7 +950,7 @@ class EnergyDataset:
         mean, std = np.mean(nonzeros), np.std(nonzeros)
         return mean, std
 
-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
         stats_path = os.path.join(cache_path, "energy_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
         self.mean = stats["mean"].astype(np.float32)
@@ -938,9 +971,7 @@ class EnergyDataset:
         return energy
 
     def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute energy and return a numpy array of energy values
-        """
+        """Compute energy and return a numpy array of energy values."""
        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
             energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)