fix(dataset): skip files where audio length can't be computed

Avoids hard failures when the audio can't be decoded.
This commit is contained in:
Enno Hermann 2024-07-31 15:20:56 +02:00
parent 20bbb411c2
commit 8c460d0cd0
1 changed files with 19 additions and 5 deletions

View File

@ -3,7 +3,7 @@ import collections
import logging
import os
import random
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
import numpy as np
import torch
@ -46,15 +46,21 @@ def string2filename(string):
return filename
def get_audio_size(audiopath) -> int:
def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
"""Return the number of samples in the audio file."""
if not isinstance(audiopath, str):
audiopath = str(audiopath)
extension = audiopath.rpartition(".")[-1].lower()
if extension not in {"mp3", "wav", "flac"}:
raise RuntimeError(
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
)
return torchaudio.info(audiopath).num_frames
try:
return torchaudio.info(audiopath).num_frames
except RuntimeError as e:
msg = f"Failed to decode {audiopath}"
raise RuntimeError(msg) from e
class TTSDataset(Dataset):
@ -186,7 +192,11 @@ class TTSDataset(Dataset):
lens = []
for item in self.samples:
_, wav_file, *_ = _parse_sample(item)
audio_len = get_audio_size(wav_file)
try:
audio_len = get_audio_size(wav_file)
except RuntimeError:
logger.warn(f"Failed to compute length for {item['audio_file']}")
audio_len = 0
lens.append(audio_len)
return lens
@ -304,7 +314,11 @@ class TTSDataset(Dataset):
def _compute_lengths(samples):
new_samples = []
for item in samples:
audio_length = get_audio_size(item["audio_file"])
try:
audio_length = get_audio_size(item["audio_file"])
except RuntimeError:
logger.warn(f"Failed to compute length, skipping {item['audio_file']}")
continue
text_lenght = len(item["text"])
item["audio_length"] = audio_length
item["text_length"] = text_lenght