Merge pull request #66 from idiap/skip-broken-audio

Skip audio files that can't be decoded
Enno Hermann 2024-07-31 15:40:21 +01:00 committed by GitHub
commit 19fce2c87c
1 changed file with 125 additions and 94 deletions
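For context, the failure mode this change handles: `torchaudio.info()` raises a `RuntimeError` when a file cannot be decoded (truncated download, zero-byte file, wrong container). Previously a single such file crashed dataset preprocessing; now it is logged and skipped. A minimal sketch of the behavior, using a hypothetical corrupt file:

```python
import torchaudio

try:
    # torchaudio.info() parses the header and returns metadata without
    # decoding the whole file; undecodable input raises RuntimeError.
    meta = torchaudio.info("data/broken.wav")  # hypothetical corrupt file
    print(meta.num_frames)
except RuntimeError as e:
    print(f"Undecodable audio: {e}")
```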


@@ -3,9 +3,10 @@ import collections
 import logging
 import os
 import random
-from typing import Dict, List, Union
+from typing import Any, Optional, Union
 
 import numpy as np
+import numpy.typing as npt
 import torch
 import torchaudio
 import tqdm
@@ -32,29 +33,34 @@ def _parse_sample(item):
     elif len(item) == 3:
         text, wav_file, speaker_name = item
     else:
-        raise ValueError(" [!] Dataset cannot parse the sample.")
+        msg = "Dataset cannot parse the sample."
+        raise ValueError(msg)
     return text, wav_file, speaker_name, language_name, attn_file
 
 
-def noise_augment_audio(wav):
+def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray:
     return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
 
 
-def string2filename(string):
+def string2filename(string: str) -> str:
     # generate a safe and reversible filename based on a string
-    filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
-    return filename
+    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
 
 
-def get_audio_size(audiopath) -> int:
+def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
     """Return the number of samples in the audio file."""
     if not isinstance(audiopath, str):
         audiopath = str(audiopath)
     extension = audiopath.rpartition(".")[-1].lower()
     if extension not in {"mp3", "wav", "flac"}:
-        raise RuntimeError(
-            f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
-        )
+        msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
+        raise RuntimeError(msg)
-    return torchaudio.info(audiopath).num_frames
+    try:
+        return torchaudio.info(audiopath).num_frames
+    except RuntimeError as e:
+        msg = f"Failed to decode {audiopath}"
+        raise RuntimeError(msg) from e
 
 
 class TTSDataset(Dataset):
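With this wrapper, `get_audio_size` fails in exactly two ways: an unsupported extension, or a decode error re-raised with the offending path attached (the original torchaudio error stays reachable through `__cause__` thanks to `raise ... from e`). A usage sketch with hypothetical paths:

```python
from pathlib import Path

# Accepts str or os.PathLike and returns the sample count.
num_samples = get_audio_size(Path("data/clip_0001.wav"))  # hypothetical valid file

# Callers that want to tolerate broken files catch the chained RuntimeError.
try:
    get_audio_size("data/broken.wav")  # hypothetical corrupt file
except RuntimeError as err:
    print(err)            # Failed to decode data/broken.wav
    print(err.__cause__)  # the original torchaudio error
```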
@@ -63,31 +69,32 @@ class TTSDataset(Dataset):
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
         ap: AudioProcessor = None,
-        samples: List[Dict] = None,
+        samples: Optional[list[dict]] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
         compute_energy: bool = False,
-        f0_cache_path: str = None,
-        energy_cache_path: str = None,
+        f0_cache_path: Optional[str] = None,
+        energy_cache_path: Optional[str] = None,
         return_wav: bool = False,
         batch_group_size: int = 0,
         min_text_len: int = 0,
         max_text_len: int = float("inf"),
         min_audio_len: int = 0,
         max_audio_len: int = float("inf"),
-        phoneme_cache_path: str = None,
+        phoneme_cache_path: Optional[str] = None,
         precompute_num_workers: int = 0,
-        speaker_id_mapping: Dict = None,
-        d_vector_mapping: Dict = None,
-        language_id_mapping: Dict = None,
+        speaker_id_mapping: Optional[dict] = None,
+        d_vector_mapping: Optional[dict] = None,
+        language_id_mapping: Optional[dict] = None,
         use_noise_augment: bool = False,
         start_by_longest: bool = False,
-    ):
+    ) -> None:
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
 
         If you need something different, you can subclass and override.
 
         Args:
+        ----
             outputs_per_step (int): Number of time frames predicted per step.
 
             compute_linear_spec (bool): compute linear spectrogram if True.
@@ -139,6 +146,7 @@ class TTSDataset(Dataset):
             use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
 
             start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
+
         """
         super().__init__()
         self.batch_group_size = batch_group_size
@@ -168,25 +176,38 @@ class TTSDataset(Dataset):
         if self.tokenizer.use_phonemes:
             self.phoneme_dataset = PhonemeDataset(
-                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.tokenizer,
+                phoneme_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
 
         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=f0_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
         if compute_energy:
             self.energy_dataset = EnergyDataset(
-                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples,
+                self.ap,
+                cache_path=energy_cache_path,
+                precompute_num_workers=precompute_num_workers,
             )
 
         self.print_logs()
 
     @property
-    def lengths(self):
+    def lengths(self) -> list[int]:
         lens = []
         for item in self.samples:
             _, wav_file, *_ = _parse_sample(item)
-            audio_len = get_audio_size(wav_file)
+            try:
+                audio_len = get_audio_size(wav_file)
+            except RuntimeError:
+                logger.warning(f"Failed to compute length for {item['audio_file']}")
+                audio_len = 0
             lens.append(audio_len)
         return lens
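Note the two call sites handle a decode failure differently, and presumably deliberately so: the `lengths` property must stay index-aligned with `self.samples` (it is consumed by external samplers), so a broken file reports a length of 0 instead of being dropped, while `_compute_lengths` below rebuilds the sample list anyway and can skip the sample outright. Roughly:

```python
# Hypothetical corpus with one undecodable file:
samples = [{"audio_file": "a.wav"}, {"audio_file": "broken.wav"}, {"audio_file": "b.wav"}]
# dataset.lengths           -> e.g. [52000, 0, 48000]  (index-aligned, bad file -> 0)
# _compute_lengths(samples) -> two samples; "broken.wav" is dropped with a warning
```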
@@ -195,7 +216,7 @@
         return self._samples
 
     @samples.setter
-    def samples(self, new_samples):
+    def samples(self, new_samples) -> None:
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
@@ -204,7 +225,7 @@
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.samples)
 
     def __getitem__(self, idx):
@@ -251,7 +272,7 @@
         token_ids = self.tokenizer.text_to_ids(text)
         return np.array(token_ids, dtype=np.int32)
 
-    def load_data(self, idx):
+    def load_data(self, idx) -> dict[str, Any]:
         item = self.samples[idx]
 
         raw_text = item["text"]
@@ -285,7 +306,7 @@
         if self.compute_energy:
             energy = self.get_energy(idx)["energy"]
 
-        sample = {
+        return {
             "raw_text": raw_text,
             "token_ids": token_ids,
             "wav": wav,
@@ -298,13 +319,16 @@
             "wav_file_name": os.path.basename(item["audio_file"]),
             "audio_unique_name": item["audio_unique_name"],
         }
-        return sample
 
     @staticmethod
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            audio_length = get_audio_size(item["audio_file"])
+            try:
+                audio_length = get_audio_size(item["audio_file"])
+            except RuntimeError:
+                logger.warning(f"Failed to compute length, skipping {item['audio_file']}")
+                continue
             text_lenght = len(item["text"])
             item["audio_length"] = audio_length
             item["text_length"] = text_lenght
@@ -312,7 +336,7 @@
         return new_samples
 
     @staticmethod
-    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
+    def filter_by_length(lengths: list[int], min_len: int, max_len: int):
         idxs = np.argsort(lengths)  # ascending order
         ignore_idx = []
         keep_idx = []
@@ -325,10 +349,9 @@
         return ignore_idx, keep_idx
 
     @staticmethod
-    def sort_by_length(samples: List[List]):
+    def sort_by_length(samples: list[list]):
         audio_lengths = [s["audio_length"] for s in samples]
-        idxs = np.argsort(audio_lengths)  # ascending order
-        return idxs
+        return np.argsort(audio_lengths)  # ascending order
 
     @staticmethod
     def create_buckets(samples, batch_group_size: int):
@@ -348,7 +371,7 @@
             samples_new.append(samples[idx])
         return samples_new
 
-    def preprocess_samples(self):
+    def preprocess_samples(self) -> None:
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples outside the length
         range.
         """
@@ -374,7 +397,8 @@
         samples = self._select_samples_by_idx(sorted_idxs, samples)
 
         if len(samples) == 0:
-            raise RuntimeError(" [!] No samples left")
+            msg = "No samples left."
+            raise RuntimeError(msg)
 
         # shuffle batch groups
         # create batches with similar length items
@@ -388,36 +412,37 @@
         self.samples = samples
 
         logger.info("Preprocessing samples")
-        logger.info("Max text length: {}".format(np.max(text_lengths)))
-        logger.info("Min text length: {}".format(np.min(text_lengths)))
-        logger.info("Avg text length: {}".format(np.mean(text_lengths)))
-        logger.info("Max audio length: {}".format(np.max(audio_lengths)))
-        logger.info("Min audio length: {}".format(np.min(audio_lengths)))
-        logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
+        logger.info(f"Max text length: {np.max(text_lengths)}")
+        logger.info(f"Min text length: {np.min(text_lengths)}")
+        logger.info(f"Avg text length: {np.mean(text_lengths)}")
+        logger.info(f"Max audio length: {np.max(audio_lengths)}")
+        logger.info(f"Min audio length: {np.min(audio_lengths)}")
+        logger.info(f"Avg audio length: {np.mean(audio_lengths)}")
         logger.info("Num. instances discarded samples: %d", len(ignore_idx))
-        logger.info("Batch group size: {}.".format(self.batch_group_size))
+        logger.info(f"Batch group size: {self.batch_group_size}.")
 
     @staticmethod
     def _sort_batch(batch, text_lengths):
         """Sort the batch by the input text length for RNN efficiency.
 
         Args:
+        ----
             batch (Dict): Batch returned by `__getitem__`.
             text_lengths (List[int]): Lengths of the input character sequences.
+
         """
         text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
         batch = [batch[idx] for idx in ids_sorted_decreasing]
         return batch, text_lengths, ids_sorted_decreasing
 
     def collate_fn(self, batch):
-        r"""
-        Perform preprocessing and create a final data batch:
+        r"""Perform preprocessing and create a final data batch.
 
         1. Sort batch instances by text-length
         2. Convert Audio signal to features.
         3. PAD sequences wrt r.
         4. Load to Torch.
         """
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):
             token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
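Step 3 in the `collate_fn` docstring above ("PAD sequences wrt r") means padding every feature sequence in the batch so its frame count is a multiple of `outputs_per_step` (r), letting the decoder emit whole groups of r frames. A standalone sketch of the idea (not the library's internal helper):

```python
import numpy as np

def pad_to_multiple(spec: np.ndarray, r: int) -> np.ndarray:
    """Pad a [T, n_mels] array along time so T becomes a multiple of r."""
    remainder = spec.shape[0] % r
    if remainder == 0:
        return spec
    return np.pad(spec, ((0, r - remainder), (0, 0)), mode="constant")

mel = np.zeros((37, 80))                # 37 frames, 80 mel bins (made-up shapes)
print(pad_to_multiple(mel, r=4).shape)  # (40, 80)
```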
@@ -562,23 +587,18 @@
             "audio_unique_names": batch["audio_unique_name"],
         }
 
-        raise TypeError(
-            (
-                "batch must contain tensors, numbers, dicts or lists;\
-                 found {}".format(
-                    type(batch[0])
-                )
-            )
-        )
+        msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}"
+        raise TypeError(msg)
 
 
 class PhonemeDataset(Dataset):
-    """Phoneme Dataset for converting input text to phonemes and then token IDs
+    """Phoneme Dataset for converting input text to phonemes and then token IDs.
 
     At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
     loading latency. If `cache_path` is already present, it skips the pre-computation.
 
     Args:
+    ----
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.
@@ -590,15 +610,16 @@
 
        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
+
    """
 
    def __init__(
        self,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
        tokenizer: "TTSTokenizer",
        cache_path: str,
-        precompute_num_workers=0,
-    ):
+        precompute_num_workers: int = 0,
+    ) -> None:
        self.samples = samples
        self.tokenizer = tokenizer
        self.cache_path = cache_path
@@ -606,16 +627,16 @@
            os.makedirs(cache_path)
        self.precompute(precompute_num_workers)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> dict[str, Any]:
        item = self.samples[index]
        ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
        ph_hat = self.tokenizer.ids_to_text(ids)
        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
 
-    def __len__(self):
+    def __len__(self) -> int:
        return len(self.samples)
 
-    def compute_or_load(self, file_name, text, language):
+    def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]:
        """Compute phonemes for the given text.
 
        If the phonemes are already cached, load them from cache.
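The cache here is a flat directory of `.npy` files keyed by `string2filename(audio_unique_name)`; a sketch of the compute-or-load pattern (the helper name and `_phoneme.npy` suffix are illustrative, not the exact method body):

```python
import os
import numpy as np

def compute_or_load_ids(cache_dir: str, key: str, compute_fn):
    """Return cached token IDs for `key`, computing and saving them on a cache miss."""
    cache_path = os.path.join(cache_dir, key + "_phoneme.npy")
    if os.path.exists(cache_path):
        return np.load(cache_path).tolist()
    ids = compute_fn()  # e.g. tokenizer.text_to_ids(text, language=language)
    np.save(cache_path, np.asarray(ids))
    return ids
```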
@@ -629,11 +650,11 @@
            np.save(cache_path, ids)
        return ids
 
-    def get_pad_id(self):
-        """Get pad token ID for sequence padding"""
+    def get_pad_id(self) -> int:
+        """Get pad token ID for sequence padding."""
        return self.tokenizer.pad_id
 
-    def precompute(self, num_workers=1):
+    def precompute(self, num_workers: int = 1) -> None:
        """Precompute phonemes for all samples.
 
        We use pytorch dataloader because we are lazy.
@@ -642,7 +663,11 @@
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
            dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
            )
            for _ in dataloder:
                pbar.update(batch_size)
@@ -667,12 +692,13 @@
 
 
 class F0Dataset:
-    """F0 Dataset for computing F0 from wav files in CPU
+    """F0 Dataset for computing F0 from wav files in CPU.
 
    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of F0 values if `normalize_f0` is True.
 
    Args:
+    ----
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.
@@ -688,17 +714,18 @@
 
        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Defaults to True.
+
    """
 
    def __init__(
        self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
        ap: "AudioProcessor",
        audio_config=None,  # pylint: disable=unused-argument
-        cache_path: str = None,
-        precompute_num_workers=0,
-        normalize_f0=True,
-    ):
+        cache_path: Optional[str] = None,
+        precompute_num_workers: int = 0,
+        normalize_f0: bool = True,
+    ) -> None:
        self.samples = samples
        self.ap = ap
        self.cache_path = cache_path
@@ -720,10 +747,10 @@
            f0 = self.normalize(f0)
        return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
 
-    def __len__(self):
+    def __len__(self) -> int:
        return len(self.samples)
 
-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers: int = 0) -> None:
        logger.info("Pre-computing F0s...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
@@ -731,7 +758,11 @@
            normalize_f0 = self.normalize_f0
            self.normalize_f0 = False
            dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
            )
            computed_data = []
            for batch in dataloder:
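Note the trick above: `normalize_f0` is switched off for the duration of precompute so the cached files hold raw F0 values; normalization then happens at read time in `__getitem__`, once mean/std have been computed over the whole corpus. The statistics come from the cached tracks, roughly as follows (values are made up; zeros mark unvoiced frames and are excluded, cf. the `nonzeros` line below):

```python
import numpy as np

f0_tracks = [np.array([0.0, 110.0, 115.0]), np.array([220.0, 0.0, 210.0])]
all_f0 = np.concatenate(f0_tracks)
nonzeros = all_f0[all_f0 != 0]
mean, std = np.mean(nonzeros), np.std(nonzeros)  # stats over voiced frames only
```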
@@ -750,9 +781,8 @@
        return self.pad_id
 
    @staticmethod
-    def create_pitch_file_path(file_name, cache_path):
-        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
-        return pitch_file
+    def create_pitch_file_path(file_name: str, cache_path: str) -> str:
+        return os.path.join(cache_path, file_name + "_pitch.npy")
 
    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
@@ -768,7 +798,7 @@
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std
 
-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
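With those statistics, normalization is the usual standardization, with the caveat that unvoiced (zero) frames must stay zero rather than being shifted to `-mean/std`. A sketch, assuming `mean` and `std` as loaded by `load_stats`:

```python
import numpy as np

def normalize_pitch(pitch: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Standardize voiced frames; keep unvoiced (zero) frames at zero."""
    zero_idxs = np.where(pitch == 0.0)[0]
    pitch = (pitch - mean) / std
    pitch[zero_idxs] = 0.0
    return pitch
```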
@@ -789,9 +819,7 @@
        return pitch
 
    def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute pitch and return a numpy array of pitch values
-        """
+        """Compute pitch and return a numpy array of pitch values."""
        pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
@@ -816,12 +844,13 @@
 
 
 class EnergyDataset:
-    """Energy Dataset for computing Energy from wav files in CPU
+    """Energy Dataset for computing Energy from wav files in CPU.
 
    Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of Energy values if `normalize_Energy` is True.
 
    Args:
+    ----
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.
@@ -837,16 +866,17 @@
 
        normalize_Energy (bool):
            Whether to normalize Energy values by mean and std. Defaults to True.
+
    """
 
    def __init__(
        self,
-        samples: Union[List[List], List[Dict]],
+        samples: Union[list[list], list[dict]],
        ap: "AudioProcessor",
-        cache_path: str = None,
+        cache_path: Optional[str] = None,
        precompute_num_workers=0,
        normalize_energy=True,
-    ):
+    ) -> None:
        self.samples = samples
        self.ap = ap
        self.cache_path = cache_path
@@ -868,10 +898,10 @@
            energy = self.normalize(energy)
        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}
 
-    def __len__(self):
+    def __len__(self) -> int:
        return len(self.samples)
 
-    def precompute(self, num_workers=0):
+    def precompute(self, num_workers=0) -> None:
        logger.info("Pre-computing energies...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
@@ -879,7 +909,11 @@
            normalize_energy = self.normalize_energy
            self.normalize_energy = False
            dataloder = torch.utils.data.DataLoader(
-                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+                batch_size=batch_size,
+                dataset=self,
+                shuffle=False,
+                num_workers=num_workers,
+                collate_fn=self.collate_fn,
            )
            computed_data = []
            for batch in dataloder:
@@ -900,8 +934,7 @@
    @staticmethod
    def create_energy_file_path(wav_file, cache_path):
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
-        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
-        return energy_file
+        return os.path.join(cache_path, file_name + "_energy.npy")
 
    @staticmethod
    def _compute_and_save_energy(ap, wav_file, energy_file=None):
@@ -917,7 +950,7 @@
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std
 
-    def load_stats(self, cache_path):
+    def load_stats(self, cache_path) -> None:
        stats_path = os.path.join(cache_path, "energy_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
@@ -938,9 +971,7 @@
        return energy
 
    def compute_or_load(self, wav_file, audio_unique_name):
-        """
-        compute energy and return a numpy array of energy values
-        """
+        """Compute energy and return a numpy array of energy values."""
        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
        if not os.path.exists(energy_file):
            energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)