mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #21 from eginhard/audio-length
refactor(dataset): get audio length with torchaudio
This commit is contained in:
commit
571f065994
|
@ -4,9 +4,9 @@ import os
|
|||
import random
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import mutagen
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
@ -43,15 +43,15 @@ def string2filename(string):
|
|||
return filename
|
||||
|
||||
|
||||
def get_audio_size(audiopath):
|
||||
def get_audio_size(audiopath) -> int:
|
||||
"""Return the number of samples in the audio file."""
|
||||
extension = audiopath.rpartition(".")[-1].lower()
|
||||
if extension not in {"mp3", "wav", "flac"}:
|
||||
raise RuntimeError(
|
||||
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
|
||||
)
|
||||
|
||||
audio_info = mutagen.File(audiopath).info
|
||||
return int(audio_info.length * audio_info.sample_rate)
|
||||
return torchaudio.info(audiopath).num_frames
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
|
|
|
@ -12,7 +12,6 @@ anyascii>=0.3.0
|
|||
pyyaml>=6.0
|
||||
fsspec[http]>=2023.6.0 # <= 2023.9.1 makes aux tests fail
|
||||
packaging>=23.1
|
||||
mutagen==1.47.0
|
||||
# deps for inference
|
||||
pysbd>=0.3.4
|
||||
# deps for notebooks
|
||||
|
|
Loading…
Reference in New Issue