From 8c460d0cd066b29188dc9be3bb53cbc488545929 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 31 Jul 2024 15:20:56 +0200 Subject: [PATCH] fix(dataset): skip files where audio length can't be computed Avoids hard failures when the audio can't be decoded. --- TTS/tts/datasets/dataset.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 3886a8f8..f718f3d4 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -3,7 +3,7 @@ import collections import logging import os import random -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch @@ -46,15 +46,21 @@ def string2filename(string): return filename -def get_audio_size(audiopath) -> int: +def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: """Return the number of samples in the audio file.""" + if not isinstance(audiopath, str): + audiopath = str(audiopath) extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: raise RuntimeError( f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" ) - return torchaudio.info(audiopath).num_frames + try: + return torchaudio.info(audiopath).num_frames + except RuntimeError as e: + msg = f"Failed to decode {audiopath}" + raise RuntimeError(msg) from e class TTSDataset(Dataset): @@ -186,7 +192,11 @@ class TTSDataset(Dataset): lens = [] for item in self.samples: _, wav_file, *_ = _parse_sample(item) - audio_len = get_audio_size(wav_file) + try: + audio_len = get_audio_size(wav_file) + except RuntimeError: + logger.warn(f"Failed to compute length for {item['audio_file']}") + audio_len = 0 lens.append(audio_len) return lens @@ -304,7 +314,11 @@ class TTSDataset(Dataset): def _compute_lengths(samples): new_samples = [] for item in samples: - audio_length = get_audio_size(item["audio_file"]) + try: + audio_length = get_audio_size(item["audio_file"]) + except RuntimeError: + logger.warn(f"Failed to compute length, skipping {item['audio_file']}") + continue text_lenght = len(item["text"]) item["audio_length"] = audio_length item["text_length"] = text_lenght