Made the tqdm `progress_bar` objects of static download methods a static class variable

This commit is contained in:
Tessa Painter 2023-11-23 17:45:26 -06:00
parent 6fef4f9067
commit c3a746fee5
1 changed files with 23 additions and 22 deletions

View File

@ -26,7 +26,9 @@ LICENSE_URLS = {
}
class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
It provides an interface to list and download
models defines in '.model.json'
@ -109,7 +111,6 @@ class ModelManager(object):
def _list_for_model_type(self, model_type):
models_name_list = []
model_count = 1
model_type = "tts_models"
models_name_list.extend(self._list_models(model_type, model_count))
return models_name_list
@ -298,22 +299,22 @@ class ModelManager(object):
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash
def ask_tos(self, model_full_path):
@staticmethod
def ask_tos(model_full_path):
"""Ask the user to agree to the terms of service"""
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
if not os.path.exists(tos_path):
print(" > You must agree to the terms of service to use this model.")
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
print(' | > "I have read, understood and agreed the Terms and Conditions." - [y/n]')
answer = input(" | | > ")
if answer.lower() == "y":
with open(tos_path, "w") as f:
f.write("I have read, understood ad agree the Terms and Conditions.")
return True
else:
return False
print(" > You must agree to the terms of service to use this model.")
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]')
answer = input(" | | > ")
if answer.lower() == "y":
with open(tos_path, "w", encoding="utf-8") as f:
f.write("I have read, understood and agreed to the Terms and Conditions.")
return True
return False
def tos_agreed(self, model_item, model_full_path):
@staticmethod
def tos_agreed(model_item, model_full_path):
"""Check if the user has agreed to the terms of service"""
if "tos_required" in model_item and model_item["tos_required"]:
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
@ -392,7 +393,7 @@ class ModelManager(object):
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
if "xtts" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
@ -406,7 +407,7 @@ class ModelManager(object):
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
@ -526,12 +527,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder)
@ -561,12 +562,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder)
@ -597,10 +598,10 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file:
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
@staticmethod