refactor(dataset): get audio length with torchaudio

Removes a (GPL) dependency
This commit is contained in:
Enno Hermann 2024-03-14 20:48:29 +01:00
parent e5c6da1c98
commit adbcba06da
2 changed files with 4 additions and 5 deletions

View File

@ -4,9 +4,9 @@ import os
import random import random
from typing import Dict, List, Union from typing import Dict, List, Union
import mutagen
import numpy as np import numpy as np
import torch import torch
import torchaudio
import tqdm import tqdm
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -43,15 +43,15 @@ def string2filename(string):
return filename return filename
def get_audio_size(audiopath): def get_audio_size(audiopath) -> int:
"""Return the number of samples in the audio file."""
extension = audiopath.rpartition(".")[-1].lower() extension = audiopath.rpartition(".")[-1].lower()
if extension not in {"mp3", "wav", "flac"}: if extension not in {"mp3", "wav", "flac"}:
raise RuntimeError( raise RuntimeError(
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
) )
audio_info = mutagen.File(audiopath).info return torchaudio.info(audiopath).num_frames
return int(audio_info.length * audio_info.sample_rate)
class TTSDataset(Dataset): class TTSDataset(Dataset):

View File

@ -12,7 +12,6 @@ anyascii>=0.3.0
pyyaml>=6.0 pyyaml>=6.0
fsspec[http]>=2023.6.0 # <= 2023.9.1 makes aux tests fail fsspec[http]>=2023.6.0 # <= 2023.9.1 makes aux tests fail
packaging>=23.1 packaging>=23.1
mutagen==1.47.0
# deps for inference # deps for inference
pysbd>=0.3.4 pysbd>=0.3.4
# deps for notebooks # deps for notebooks