Made the tqdm `progress_bar` objects of static download methods a static class variable (#3297)

This commit is contained in:
Tessa Painter 2023-11-24 05:23:59 -06:00 committed by GitHub
parent b47d9c6e36
commit 64f391b583
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 6 deletions

View File

@ -26,7 +26,9 @@ LICENSE_URLS = {
} }
class ModelManager(object): class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json. """Manage TTS models defined in .models.json.
It provides an interface to list and download It provides an interface to list and download
models defines in '.model.json' models defines in '.model.json'
@ -525,12 +527,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
with zipfile.ZipFile(temp_zip_name) as z: with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder) z.extractall(output_folder)
@ -560,12 +562,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file: with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
with tarfile.open(temp_tar_name) as t: with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder) t.extractall(output_folder)
@ -596,10 +598,10 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
@staticmethod @staticmethod