From 8c460d0cd066b29188dc9be3bb53cbc488545929 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 31 Jul 2024 15:20:56 +0200 Subject: [PATCH 1/2] fix(dataset): skip files where audio length can't be computed Avoids hard failures when the audio can't be decoded. --- TTS/tts/datasets/dataset.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 3886a8f8..f718f3d4 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -3,7 +3,7 @@ import collections import logging import os import random -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch @@ -46,15 +46,21 @@ def string2filename(string): return filename -def get_audio_size(audiopath) -> int: +def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: """Return the number of samples in the audio file.""" + if not isinstance(audiopath, str): + audiopath = str(audiopath) extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: raise RuntimeError( f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" ) - return torchaudio.info(audiopath).num_frames + try: + return torchaudio.info(audiopath).num_frames + except RuntimeError as e: + msg = f"Failed to decode {audiopath}" + raise RuntimeError(msg) from e class TTSDataset(Dataset): @@ -186,7 +192,11 @@ class TTSDataset(Dataset): lens = [] for item in self.samples: _, wav_file, *_ = _parse_sample(item) - audio_len = get_audio_size(wav_file) + try: + audio_len = get_audio_size(wav_file) + except RuntimeError: + logger.warn(f"Failed to compute length for {item['audio_file']}") + audio_len = 0 lens.append(audio_len) return lens @@ -304,7 +314,11 @@ class TTSDataset(Dataset): def _compute_lengths(samples): new_samples = [] for item in samples: - audio_length = get_audio_size(item["audio_file"]) + try: + audio_length = get_audio_size(item["audio_file"]) + except RuntimeError: + logger.warn(f"Failed to compute length, skipping {item['audio_file']}") + continue text_lenght = len(item["text"]) item["audio_length"] = audio_length item["text_length"] = text_lenght From 9c604c1de0af05cc0863f687ec695d9b43864c4c Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 31 Jul 2024 15:40:46 +0200 Subject: [PATCH 2/2] chore(dataset): address lint issues --- TTS/tts/datasets/dataset.py | 201 +++++++++++++++++++----------------- 1 file changed, 109 insertions(+), 92 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index f718f3d4..37e3a177 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -3,9 +3,10 @@ import collections import logging import os import random -from typing import Any, Dict, List, Union +from typing import Any, Optional, Union import numpy as np +import numpy.typing as npt import torch import torchaudio import tqdm @@ -32,18 +33,18 @@ def _parse_sample(item): elif len(item) == 3: text, wav_file, speaker_name = item else: - raise ValueError(" [!] Dataset cannot parse the sample.") + msg = "Dataset cannot parse the sample." + raise ValueError(msg) return text, wav_file, speaker_name, language_name, attn_file -def noise_augment_audio(wav): +def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray: return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) -def string2filename(string): +def string2filename(string: str) -> str: # generate a safe and reversible filename based on a string - filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") - return filename + return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: @@ -52,9 +53,8 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: audiopath = str(audiopath) extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: - raise RuntimeError( - f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" - ) + msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" + raise RuntimeError(msg) try: return torchaudio.info(audiopath).num_frames @@ -69,31 +69,32 @@ class TTSDataset(Dataset): outputs_per_step: int = 1, compute_linear_spec: bool = False, ap: AudioProcessor = None, - samples: List[Dict] = None, + samples: Optional[list[dict]] = None, tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, compute_energy: bool = False, - f0_cache_path: str = None, - energy_cache_path: str = None, + f0_cache_path: Optional[str] = None, + energy_cache_path: Optional[str] = None, return_wav: bool = False, batch_group_size: int = 0, min_text_len: int = 0, max_text_len: int = float("inf"), min_audio_len: int = 0, max_audio_len: int = float("inf"), - phoneme_cache_path: str = None, + phoneme_cache_path: Optional[str] = None, precompute_num_workers: int = 0, - speaker_id_mapping: Dict = None, - d_vector_mapping: Dict = None, - language_id_mapping: Dict = None, + speaker_id_mapping: Optional[dict] = None, + d_vector_mapping: Optional[dict] = None, + language_id_mapping: Optional[dict] = None, use_noise_augment: bool = False, start_by_longest: bool = False, - ): + ) -> None: """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. If you need something different, you can subclass and override. Args: + ---- outputs_per_step (int): Number of time frames predicted per step. compute_linear_spec (bool): compute linear spectrogram if True. @@ -145,6 +146,7 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + """ super().__init__() self.batch_group_size = batch_group_size @@ -174,28 +176,37 @@ class TTSDataset(Dataset): if self.tokenizer.use_phonemes: self.phoneme_dataset = PhonemeDataset( - self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.tokenizer, + phoneme_cache_path, + precompute_num_workers=precompute_num_workers, ) if compute_f0: self.f0_dataset = F0Dataset( - self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.ap, + cache_path=f0_cache_path, + precompute_num_workers=precompute_num_workers, ) if compute_energy: self.energy_dataset = EnergyDataset( - self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers + self.samples, + self.ap, + cache_path=energy_cache_path, + precompute_num_workers=precompute_num_workers, ) self.print_logs() @property - def lengths(self): + def lengths(self) -> list[int]: lens = [] for item in self.samples: _, wav_file, *_ = _parse_sample(item) try: audio_len = get_audio_size(wav_file) except RuntimeError: - logger.warn(f"Failed to compute length for {item['audio_file']}") + logger.warning(f"Failed to compute length for {item['audio_file']}") audio_len = 0 lens.append(audio_len) return lens @@ -205,7 +216,7 @@ class TTSDataset(Dataset): return self._samples @samples.setter - def samples(self, new_samples): + def samples(self, new_samples) -> None: self._samples = new_samples if hasattr(self, "f0_dataset"): self.f0_dataset.samples = new_samples @@ -214,7 +225,7 @@ class TTSDataset(Dataset): if hasattr(self, "phoneme_dataset"): self.phoneme_dataset.samples = new_samples - def __len__(self): + def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx): @@ -261,7 +272,7 @@ class TTSDataset(Dataset): token_ids = self.tokenizer.text_to_ids(text) return np.array(token_ids, dtype=np.int32) - def load_data(self, idx): + def load_data(self, idx) -> dict[str, Any]: item = self.samples[idx] raw_text = item["text"] @@ -295,7 +306,7 @@ class TTSDataset(Dataset): if self.compute_energy: energy = self.get_energy(idx)["energy"] - sample = { + return { "raw_text": raw_text, "token_ids": token_ids, "wav": wav, @@ -308,7 +319,6 @@ class TTSDataset(Dataset): "wav_file_name": os.path.basename(item["audio_file"]), "audio_unique_name": item["audio_unique_name"], } - return sample @staticmethod def _compute_lengths(samples): @@ -317,7 +327,7 @@ class TTSDataset(Dataset): try: audio_length = get_audio_size(item["audio_file"]) except RuntimeError: - logger.warn(f"Failed to compute length, skipping {item['audio_file']}") + logger.warning(f"Failed to compute length, skipping {item['audio_file']}") continue text_lenght = len(item["text"]) item["audio_length"] = audio_length @@ -326,7 +336,7 @@ class TTSDataset(Dataset): return new_samples @staticmethod - def filter_by_length(lengths: List[int], min_len: int, max_len: int): + def filter_by_length(lengths: list[int], min_len: int, max_len: int): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] @@ -339,10 +349,9 @@ class TTSDataset(Dataset): return ignore_idx, keep_idx @staticmethod - def sort_by_length(samples: List[List]): + def sort_by_length(samples: list[list]): audio_lengths = [s["audio_length"] for s in samples] - idxs = np.argsort(audio_lengths) # ascending order - return idxs + return np.argsort(audio_lengths) # ascending order @staticmethod def create_buckets(samples, batch_group_size: int): @@ -362,7 +371,7 @@ class TTSDataset(Dataset): samples_new.append(samples[idx]) return samples_new - def preprocess_samples(self): + def preprocess_samples(self) -> None: r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length range. """ @@ -388,7 +397,8 @@ class TTSDataset(Dataset): samples = self._select_samples_by_idx(sorted_idxs, samples) if len(samples) == 0: - raise RuntimeError(" [!] No samples left") + msg = "No samples left." + raise RuntimeError(msg) # shuffle batch groups # create batches with similar length items @@ -402,36 +412,37 @@ class TTSDataset(Dataset): self.samples = samples logger.info("Preprocessing samples") - logger.info("Max text length: {}".format(np.max(text_lengths))) - logger.info("Min text length: {}".format(np.min(text_lengths))) - logger.info("Avg text length: {}".format(np.mean(text_lengths))) - logger.info("Max audio length: {}".format(np.max(audio_lengths))) - logger.info("Min audio length: {}".format(np.min(audio_lengths))) - logger.info("Avg audio length: {}".format(np.mean(audio_lengths))) + logger.info(f"Max text length: {np.max(text_lengths)}") + logger.info(f"Min text length: {np.min(text_lengths)}") + logger.info(f"Avg text length: {np.mean(text_lengths)}") + logger.info(f"Max audio length: {np.max(audio_lengths)}") + logger.info(f"Min audio length: {np.min(audio_lengths)}") + logger.info(f"Avg audio length: {np.mean(audio_lengths)}") logger.info("Num. instances discarded samples: %d", len(ignore_idx)) - logger.info("Batch group size: {}.".format(self.batch_group_size)) + logger.info(f"Batch group size: {self.batch_group_size}.") @staticmethod def _sort_batch(batch, text_lengths): """Sort the batch by the input text length for RNN efficiency. Args: + ---- batch (Dict): Batch returned by `__getitem__`. text_lengths (List[int]): Lengths of the input character sequences. + """ text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) batch = [batch[idx] for idx in ids_sorted_decreasing] return batch, text_lengths, ids_sorted_decreasing def collate_fn(self, batch): - r""" - Perform preprocessing and create a final data batch: + """Perform preprocessing and create a final data batch. + 1. Sort batch instances by text-length 2. Convert Audio signal to features. 3. PAD sequences wrt r. 4. Load to Torch. """ - # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) @@ -576,23 +587,18 @@ class TTSDataset(Dataset): "audio_unique_names": batch["audio_unique_name"], } - raise TypeError( - ( - "batch must contain tensors, numbers, dicts or lists;\ - found {}".format( - type(batch[0]) - ) - ) - ) + msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}" + raise TypeError(msg) class PhonemeDataset(Dataset): - """Phoneme Dataset for converting input text to phonemes and then token IDs + """Phoneme Dataset for converting input text to phonemes and then token IDs. At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data loading latency. If `cache_path` is already present, it skips the pre-computation. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -604,15 +610,16 @@ class PhonemeDataset(Dataset): precompute_num_workers (int): Number of workers used for pre-computing the phonemes. Defaults to 0. + """ def __init__( self, - samples: Union[List[Dict], List[List]], + samples: Union[list[dict], list[list]], tokenizer: "TTSTokenizer", cache_path: str, - precompute_num_workers=0, - ): + precompute_num_workers: int = 0, + ) -> None: self.samples = samples self.tokenizer = tokenizer self.cache_path = cache_path @@ -620,16 +627,16 @@ class PhonemeDataset(Dataset): os.makedirs(cache_path) self.precompute(precompute_num_workers) - def __getitem__(self, index): + def __getitem__(self, index) -> dict[str, Any]: item = self.samples[index] ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"]) ph_hat = self.tokenizer.ids_to_text(ids) return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def compute_or_load(self, file_name, text, language): + def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]: """Compute phonemes for the given text. If the phonemes are already cached, load them from cache. @@ -643,11 +650,11 @@ class PhonemeDataset(Dataset): np.save(cache_path, ids) return ids - def get_pad_id(self): - """Get pad token ID for sequence padding""" + def get_pad_id(self) -> int: + """Get pad token ID for sequence padding.""" return self.tokenizer.pad_id - def precompute(self, num_workers=1): + def precompute(self, num_workers: int = 1) -> None: """Precompute phonemes for all samples. We use pytorch dataloader because we are lazy. @@ -656,7 +663,11 @@ class PhonemeDataset(Dataset): with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) for _ in dataloder: pbar.update(batch_size) @@ -681,12 +692,13 @@ class PhonemeDataset(Dataset): class F0Dataset: - """F0 Dataset for computing F0 from wav files in CPU + """F0 Dataset for computing F0 from wav files in CPU. Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It also computes the mean and std of F0 values if `normalize_f0` is True. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -702,17 +714,18 @@ class F0Dataset: normalize_f0 (bool): Whether to normalize F0 values by mean and std. Defaults to True. + """ def __init__( self, - samples: Union[List[List], List[Dict]], + samples: Union[list[list], list[dict]], ap: "AudioProcessor", audio_config=None, # pylint: disable=unused-argument - cache_path: str = None, - precompute_num_workers=0, - normalize_f0=True, - ): + cache_path: Optional[str] = None, + precompute_num_workers: int = 0, + normalize_f0: bool = True, + ) -> None: self.samples = samples self.ap = ap self.cache_path = cache_path @@ -734,10 +747,10 @@ class F0Dataset: f0 = self.normalize(f0) return {"audio_unique_name": item["audio_unique_name"], "f0": f0} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def precompute(self, num_workers=0): + def precompute(self, num_workers: int = 0) -> None: logger.info("Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 @@ -745,7 +758,11 @@ class F0Dataset: normalize_f0 = self.normalize_f0 self.normalize_f0 = False dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) computed_data = [] for batch in dataloder: @@ -764,9 +781,8 @@ class F0Dataset: return self.pad_id @staticmethod - def create_pitch_file_path(file_name, cache_path): - pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") - return pitch_file + def create_pitch_file_path(file_name: str, cache_path: str) -> str: + return os.path.join(cache_path, file_name + "_pitch.npy") @staticmethod def _compute_and_save_pitch(ap, wav_file, pitch_file=None): @@ -782,7 +798,7 @@ class F0Dataset: mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def load_stats(self, cache_path): + def load_stats(self, cache_path) -> None: stats_path = os.path.join(cache_path, "pitch_stats.npy") stats = np.load(stats_path, allow_pickle=True).item() self.mean = stats["mean"].astype(np.float32) @@ -803,9 +819,7 @@ class F0Dataset: return pitch def compute_or_load(self, wav_file, audio_unique_name): - """ - compute pitch and return a numpy array of pitch values - """ + """Compute pitch and return a numpy array of pitch values.""" pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) if not os.path.exists(pitch_file): pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) @@ -830,12 +844,13 @@ class F0Dataset: class EnergyDataset: - """Energy Dataset for computing Energy from wav files in CPU + """Energy Dataset for computing Energy from wav files in CPU. Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It also computes the mean and std of Energy values if `normalize_Energy` is True. Args: + ---- samples (Union[List[List], List[Dict]]): List of samples. Each sample is a list or a dict. @@ -851,16 +866,17 @@ class EnergyDataset: normalize_Energy (bool): Whether to normalize Energy values by mean and std. Defaults to True. + """ def __init__( self, - samples: Union[List[List], List[Dict]], + samples: Union[list[list], list[dict]], ap: "AudioProcessor", - cache_path: str = None, + cache_path: Optional[str] = None, precompute_num_workers=0, normalize_energy=True, - ): + ) -> None: self.samples = samples self.ap = ap self.cache_path = cache_path @@ -882,10 +898,10 @@ class EnergyDataset: energy = self.normalize(energy) return {"audio_unique_name": item["audio_unique_name"], "energy": energy} - def __len__(self): + def __len__(self) -> int: return len(self.samples) - def precompute(self, num_workers=0): + def precompute(self, num_workers=0) -> None: logger.info("Pre-computing energys...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 @@ -893,7 +909,11 @@ class EnergyDataset: normalize_energy = self.normalize_energy self.normalize_energy = False dataloder = torch.utils.data.DataLoader( - batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, ) computed_data = [] for batch in dataloder: @@ -914,8 +934,7 @@ class EnergyDataset: @staticmethod def create_energy_file_path(wav_file, cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] - energy_file = os.path.join(cache_path, file_name + "_energy.npy") - return energy_file + return os.path.join(cache_path, file_name + "_energy.npy") @staticmethod def _compute_and_save_energy(ap, wav_file, energy_file=None): @@ -931,7 +950,7 @@ class EnergyDataset: mean, std = np.mean(nonzeros), np.std(nonzeros) return mean, std - def load_stats(self, cache_path): + def load_stats(self, cache_path) -> None: stats_path = os.path.join(cache_path, "energy_stats.npy") stats = np.load(stats_path, allow_pickle=True).item() self.mean = stats["mean"].astype(np.float32) @@ -952,9 +971,7 @@ class EnergyDataset: return energy def compute_or_load(self, wav_file, audio_unique_name): - """ - compute energy and return a numpy array of energy values - """ + """Compute energy and return a numpy array of energy values.""" energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) if not os.path.exists(energy_file): energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)