mirror of https://github.com/coqui-ai/TTS.git
fix(dataset): skip files where audio length can't be computed
Avoids hard failures when the audio can't be decoded.
This commit is contained in:
parent
20bbb411c2
commit
8c460d0cd0
|
@ -3,7 +3,7 @@ import collections
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -46,15 +46,21 @@ def string2filename(string):
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
|
||||||
def get_audio_size(audiopath) -> int:
|
def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
|
||||||
"""Return the number of samples in the audio file."""
|
"""Return the number of samples in the audio file."""
|
||||||
|
if not isinstance(audiopath, str):
|
||||||
|
audiopath = str(audiopath)
|
||||||
extension = audiopath.rpartition(".")[-1].lower()
|
extension = audiopath.rpartition(".")[-1].lower()
|
||||||
if extension not in {"mp3", "wav", "flac"}:
|
if extension not in {"mp3", "wav", "flac"}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
|
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
return torchaudio.info(audiopath).num_frames
|
return torchaudio.info(audiopath).num_frames
|
||||||
|
except RuntimeError as e:
|
||||||
|
msg = f"Failed to decode {audiopath}"
|
||||||
|
raise RuntimeError(msg) from e
|
||||||
|
|
||||||
|
|
||||||
class TTSDataset(Dataset):
|
class TTSDataset(Dataset):
|
||||||
|
@ -186,7 +192,11 @@ class TTSDataset(Dataset):
|
||||||
lens = []
|
lens = []
|
||||||
for item in self.samples:
|
for item in self.samples:
|
||||||
_, wav_file, *_ = _parse_sample(item)
|
_, wav_file, *_ = _parse_sample(item)
|
||||||
|
try:
|
||||||
audio_len = get_audio_size(wav_file)
|
audio_len = get_audio_size(wav_file)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warn(f"Failed to compute length for {item['audio_file']}")
|
||||||
|
audio_len = 0
|
||||||
lens.append(audio_len)
|
lens.append(audio_len)
|
||||||
return lens
|
return lens
|
||||||
|
|
||||||
|
@ -304,7 +314,11 @@ class TTSDataset(Dataset):
|
||||||
def _compute_lengths(samples):
|
def _compute_lengths(samples):
|
||||||
new_samples = []
|
new_samples = []
|
||||||
for item in samples:
|
for item in samples:
|
||||||
|
try:
|
||||||
audio_length = get_audio_size(item["audio_file"])
|
audio_length = get_audio_size(item["audio_file"])
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warn(f"Failed to compute length, skipping {item['audio_file']}")
|
||||||
|
continue
|
||||||
text_lenght = len(item["text"])
|
text_lenght = len(item["text"])
|
||||||
item["audio_length"] = audio_length
|
item["audio_length"] = audio_length
|
||||||
item["text_length"] = text_lenght
|
item["text_length"] = text_lenght
|
||||||
|
|
Loading…
Reference in New Issue