Made the tqdm `progress_bar` objects of static download methods a static class variable

This commit is contained in:
Tessa Painter 2023-11-23 17:45:26 -06:00
parent 6fef4f9067
commit c3a746fee5
1 changed files with 23 additions and 22 deletions

View File

@ -26,7 +26,9 @@ LICENSE_URLS = {
} }
class ModelManager(object): class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json. """Manage TTS models defined in .models.json.
It provides an interface to list and download It provides an interface to list and download
models defines in '.model.json' models defines in '.model.json'
@ -109,7 +111,6 @@ class ModelManager(object):
def _list_for_model_type(self, model_type): def _list_for_model_type(self, model_type):
models_name_list = [] models_name_list = []
model_count = 1 model_count = 1
model_type = "tts_models"
models_name_list.extend(self._list_models(model_type, model_count)) models_name_list.extend(self._list_models(model_type, model_count))
return models_name_list return models_name_list
@ -298,22 +299,22 @@ class ModelManager(object):
model_item = self.set_model_url(model_item) model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash return model_item, model_full_name, model, md5hash
def ask_tos(self, model_full_path): @staticmethod
def ask_tos(model_full_path):
"""Ask the user to agree to the terms of service""" """Ask the user to agree to the terms of service"""
tos_path = os.path.join(model_full_path, "tos_agreed.txt") tos_path = os.path.join(model_full_path, "tos_agreed.txt")
if not os.path.exists(tos_path):
print(" > You must agree to the terms of service to use this model.") print(" > You must agree to the terms of service to use this model.")
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt") print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
print(' | > "I have read, understood and agreed the Terms and Conditions." - [y/n]') print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]')
answer = input(" | | > ") answer = input(" | | > ")
if answer.lower() == "y": if answer.lower() == "y":
with open(tos_path, "w") as f: with open(tos_path, "w", encoding="utf-8") as f:
f.write("I have read, understood ad agree the Terms and Conditions.") f.write("I have read, understood and agreed to the Terms and Conditions.")
return True return True
else:
return False return False
def tos_agreed(self, model_item, model_full_path): @staticmethod
def tos_agreed(model_item, model_full_path):
"""Check if the user has agreed to the terms of service""" """Check if the user has agreed to the terms of service"""
if "tos_required" in model_item and model_item["tos_required"]: if "tos_required" in model_item and model_item["tos_required"]:
tos_path = os.path.join(model_full_path, "tos_agreed.txt") tos_path = os.path.join(model_full_path, "tos_agreed.txt")
@ -392,7 +393,7 @@ class ModelManager(object):
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it # if the configs are different, redownload it
# ToDo: we need a better way to handle it # ToDo: we need a better way to handle it
if "xtts_v1" in model_name: if "xtts" in model_name:
try: try:
self.check_if_configs_are_equal(model_name, model_item, output_path) self.check_if_configs_are_equal(model_name, model_item, output_path)
except: except:
@ -406,7 +407,7 @@ class ModelManager(object):
output_model_path = output_path output_model_path = output_path
output_config_path = None output_config_path = None
if ( if (
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
): # TODO:This is stupid but don't care for now. ): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path) output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json # update paths in the config.json
@ -526,12 +527,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
with zipfile.ZipFile(temp_zip_name) as z: with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder) z.extractall(output_folder)
@ -561,12 +562,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file: with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
with tarfile.open(temp_tar_name) as t: with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder) t.extractall(output_folder)
@ -597,10 +598,10 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
if progress_bar: if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
progress_bar.update(len(data)) ModelManager.tqdm_progress.update(len(data))
file.write(data) file.write(data)
@staticmethod @staticmethod