mirror of https://github.com/coqui-ai/TTS.git
fix(dataset): skip files where audio length can't be computed
Avoids hard failures when the audio can't be decoded.
This commit is contained in:
parent
20bbb411c2
commit
8c460d0cd0
|
@ -3,7 +3,7 @@ import collections
|
|||
import logging
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -46,15 +46,21 @@ def string2filename(string):
|
|||
return filename
|
||||
|
||||
|
||||
def get_audio_size(audiopath) -> int:
|
||||
def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
|
||||
"""Return the number of samples in the audio file."""
|
||||
if not isinstance(audiopath, str):
|
||||
audiopath = str(audiopath)
|
||||
extension = audiopath.rpartition(".")[-1].lower()
|
||||
if extension not in {"mp3", "wav", "flac"}:
|
||||
raise RuntimeError(
|
||||
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
|
||||
)
|
||||
|
||||
return torchaudio.info(audiopath).num_frames
|
||||
try:
|
||||
return torchaudio.info(audiopath).num_frames
|
||||
except RuntimeError as e:
|
||||
msg = f"Failed to decode {audiopath}"
|
||||
raise RuntimeError(msg) from e
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
|
@ -186,7 +192,11 @@ class TTSDataset(Dataset):
|
|||
lens = []
|
||||
for item in self.samples:
|
||||
_, wav_file, *_ = _parse_sample(item)
|
||||
audio_len = get_audio_size(wav_file)
|
||||
try:
|
||||
audio_len = get_audio_size(wav_file)
|
||||
except RuntimeError:
|
||||
logger.warn(f"Failed to compute length for {item['audio_file']}")
|
||||
audio_len = 0
|
||||
lens.append(audio_len)
|
||||
return lens
|
||||
|
||||
|
@ -304,7 +314,11 @@ class TTSDataset(Dataset):
|
|||
def _compute_lengths(samples):
|
||||
new_samples = []
|
||||
for item in samples:
|
||||
audio_length = get_audio_size(item["audio_file"])
|
||||
try:
|
||||
audio_length = get_audio_size(item["audio_file"])
|
||||
except RuntimeError:
|
||||
logger.warn(f"Failed to compute length, skipping {item['audio_file']}")
|
||||
continue
|
||||
text_lenght = len(item["text"])
|
||||
item["audio_length"] = audio_length
|
||||
item["text_length"] = text_lenght
|
||||
|
|
Loading…
Reference in New Issue