fix(dataset): skip files where audio length can't be computed

Avoids hard failures when the audio can't be decoded.
This commit is contained in:
Enno Hermann 2024-07-31 15:20:56 +02:00
parent 20bbb411c2
commit 8c460d0cd0
1 changed file with 19 additions and 5 deletions

View File

@ -3,7 +3,7 @@ import collections
import logging import logging
import os import os
import random import random
from typing import Dict, List, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import torch import torch
@ -46,15 +46,21 @@ def string2filename(string):
return filename return filename
def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
    """Return the number of samples in the audio file.

    Args:
        audiopath: Path to an mp3, wav or flac file, as a string or a
            path-like object.

    Returns:
        The ``num_frames`` value reported by ``torchaudio.info``.

    Raises:
        RuntimeError: If the file extension is not mp3, wav or flac, or if
            ``torchaudio.info`` raises a ``RuntimeError`` for the file.
    """
    if not isinstance(audiopath, str):
        audiopath = str(audiopath)
    # Only the final path component can carry the file extension. Searching
    # the whole path would mistake a dot in a directory name (or the entire
    # path, when there is no dot at all) for the audio format.
    filename = os.path.basename(audiopath)
    _, dot, suffix = filename.rpartition(".")
    extension = suffix.lower() if dot else ""
    if extension not in {"mp3", "wav", "flac"}:
        raise RuntimeError(
            f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
        )
    try:
        return torchaudio.info(audiopath).num_frames
    except RuntimeError as e:
        # Re-raise with the offending path so callers can log or skip the file.
        msg = f"Failed to decode {audiopath}"
        raise RuntimeError(msg) from e
class TTSDataset(Dataset): class TTSDataset(Dataset):
@ -186,7 +192,11 @@ class TTSDataset(Dataset):
lens = [] lens = []
for item in self.samples: for item in self.samples:
_, wav_file, *_ = _parse_sample(item) _, wav_file, *_ = _parse_sample(item)
try:
audio_len = get_audio_size(wav_file) audio_len = get_audio_size(wav_file)
except RuntimeError:
logger.warn(f"Failed to compute length for {item['audio_file']}")
audio_len = 0
lens.append(audio_len) lens.append(audio_len)
return lens return lens
@ -304,7 +314,11 @@ class TTSDataset(Dataset):
def _compute_lengths(samples): def _compute_lengths(samples):
new_samples = [] new_samples = []
for item in samples: for item in samples:
try:
audio_length = get_audio_size(item["audio_file"]) audio_length = get_audio_size(item["audio_file"])
except RuntimeError:
logger.warn(f"Failed to compute length, skipping {item['audio_file']}")
continue
text_lenght = len(item["text"]) text_lenght = len(item["text"])
item["audio_length"] = audio_length item["audio_length"] = audio_length
item["text_length"] = text_lenght item["text_length"] = text_lenght