mirror of https://github.com/coqui-ai/TTS.git
Add initial data downloaders for stt
This commit is contained in:
parent
4157e99d2d
commit
89cbfbc829
|
@ -0,0 +1,185 @@
|
|||
# Adapted from https://github.com/pytorch/audio/
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import tarfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from torch.utils.model_zoo import tqdm
|
||||
|
||||
|
||||
def stream_url(
|
||||
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
|
||||
) -> Iterable:
|
||||
"""Stream url by chunk
|
||||
|
||||
Args:
|
||||
url (str): Url.
|
||||
start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
|
||||
block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
|
||||
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
|
||||
"""
|
||||
|
||||
# If we already have the whole file, there is no need to download it again
|
||||
req = urllib.request.Request(url, method="HEAD")
|
||||
with urllib.request.urlopen(req) as response:
|
||||
url_size = int(response.info().get("Content-Length", -1))
|
||||
if url_size == start_byte:
|
||||
return
|
||||
|
||||
req = urllib.request.Request(url)
|
||||
if start_byte:
|
||||
req.headers["Range"] = "bytes={}-".format(start_byte)
|
||||
|
||||
with urllib.request.urlopen(req) as upointer, tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=url_size,
|
||||
disable=not progress_bar,
|
||||
) as pbar:
|
||||
|
||||
num_bytes = 0
|
||||
while True:
|
||||
chunk = upointer.read(block_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
num_bytes += len(chunk)
|
||||
pbar.update(len(chunk))
|
||||
|
||||
|
||||
def download_url(
|
||||
url: str,
|
||||
download_folder: str,
|
||||
filename: Optional[str] = None,
|
||||
hash_value: Optional[str] = None,
|
||||
hash_type: str = "sha256",
|
||||
progress_bar: bool = True,
|
||||
resume: bool = False,
|
||||
) -> None:
|
||||
"""Download file to disk.
|
||||
|
||||
Args:
|
||||
url (str): Url.
|
||||
download_folder (str): Folder to download file.
|
||||
filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
|
||||
(Default: ``None``).
|
||||
hash_value (str or None, optional): Hash for url (Default: ``None``).
|
||||
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
|
||||
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
|
||||
resume (bool, optional): Enable resuming download (Default: ``False``).
|
||||
"""
|
||||
|
||||
req = urllib.request.Request(url, method="HEAD")
|
||||
req_info = urllib.request.urlopen(req).info()
|
||||
|
||||
# Detect filename
|
||||
filename = filename or req_info.get_filename() or os.path.basename(url)
|
||||
filepath = os.path.join(download_folder, filename)
|
||||
if resume and os.path.exists(filepath):
|
||||
mode = "ab"
|
||||
local_size: Optional[int] = os.path.getsize(filepath)
|
||||
|
||||
elif not resume and os.path.exists(filepath):
|
||||
raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
|
||||
else:
|
||||
mode = "wb"
|
||||
local_size = None
|
||||
|
||||
if hash_value and local_size == int(req_info.get("Content-Length", -1)):
|
||||
with open(filepath, "rb") as file_obj:
|
||||
if validate_file(file_obj, hash_value, hash_type):
|
||||
return
|
||||
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
|
||||
|
||||
with open(filepath, mode) as fpointer:
|
||||
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
|
||||
fpointer.write(chunk)
|
||||
|
||||
with open(filepath, "rb") as file_obj:
|
||||
if hash_value and not validate_file(file_obj, hash_value, hash_type):
|
||||
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
|
||||
|
||||
|
||||
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
|
||||
"""Validate a given file object with its hash.
|
||||
|
||||
Args:
|
||||
file_obj: File object to read from.
|
||||
hash_value (str): Hash for url.
|
||||
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
|
||||
|
||||
Returns:
|
||||
bool: return True if its a valid file, else False.
|
||||
"""
|
||||
|
||||
if hash_type == "sha256":
|
||||
hash_func = hashlib.sha256()
|
||||
elif hash_type == "md5":
|
||||
hash_func = hashlib.md5()
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
while True:
|
||||
# Read by chunk to avoid filling memory
|
||||
chunk = file_obj.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
hash_func.update(chunk)
|
||||
|
||||
return hash_func.hexdigest() == hash_value
|
||||
|
||||
|
||||
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
|
||||
"""Extract archive.
|
||||
Args:
|
||||
from_path (str): the path of the archive.
|
||||
to_path (str or None, optional): the root path of the extraced files (directory of from_path)
|
||||
(Default: ``None``)
|
||||
overwrite (bool, optional): overwrite existing files (Default: ``False``)
|
||||
|
||||
Returns:
|
||||
list: List of paths to extracted files even if not overwritten.
|
||||
"""
|
||||
|
||||
if to_path is None:
|
||||
to_path = os.path.dirname(from_path)
|
||||
|
||||
try:
|
||||
with tarfile.open(from_path, "r") as tar:
|
||||
logging.info("Opened tar file {}.".format(from_path))
|
||||
files = []
|
||||
for file_ in tar: # type: Any
|
||||
file_path = os.path.join(to_path, file_.name)
|
||||
if file_.isfile():
|
||||
files.append(file_path)
|
||||
if os.path.exists(file_path):
|
||||
logging.info("{} already extracted.".format(file_path))
|
||||
if not overwrite:
|
||||
continue
|
||||
tar.extract(file_, to_path)
|
||||
return files
|
||||
except tarfile.ReadError:
|
||||
pass
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(from_path, "r") as zfile:
|
||||
logging.info("Opened zip file {}.".format(from_path))
|
||||
files = zfile.namelist()
|
||||
for file_ in files:
|
||||
file_path = os.path.join(to_path, file_)
|
||||
if os.path.exists(file_path):
|
||||
logging.info("{} already extracted.".format(file_path))
|
||||
if not overwrite:
|
||||
continue
|
||||
zfile.extract(file_, to_path)
|
||||
return files
|
||||
except zipfile.BadZipFile:
|
||||
pass
|
||||
|
||||
raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
|
Loading…
Reference in New Issue