mirror of https://github.com/coqui-ai/TTS.git
Made the tqdm `progress_bar` objects of static download methods a static class variable
This commit is contained in:
parent
6fef4f9067
commit
c3a746fee5
|
@ -26,7 +26,9 @@ LICENSE_URLS = {
|
|||
}
|
||||
|
||||
|
||||
|
||||
class ModelManager(object):
|
||||
tqdm_progress = None
|
||||
"""Manage TTS models defined in .models.json.
|
||||
It provides an interface to list and download
|
||||
models defines in '.model.json'
|
||||
|
@ -109,7 +111,6 @@ class ModelManager(object):
|
|||
def _list_for_model_type(self, model_type):
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
model_type = "tts_models"
|
||||
models_name_list.extend(self._list_models(model_type, model_count))
|
||||
return models_name_list
|
||||
|
||||
|
@ -298,22 +299,22 @@ class ModelManager(object):
|
|||
model_item = self.set_model_url(model_item)
|
||||
return model_item, model_full_name, model, md5hash
|
||||
|
||||
def ask_tos(self, model_full_path):
|
||||
@staticmethod
|
||||
def ask_tos(model_full_path):
|
||||
"""Ask the user to agree to the terms of service"""
|
||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||
if not os.path.exists(tos_path):
|
||||
print(" > You must agree to the terms of service to use this model.")
|
||||
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
|
||||
print(' | > "I have read, understood and agreed the Terms and Conditions." - [y/n]')
|
||||
answer = input(" | | > ")
|
||||
if answer.lower() == "y":
|
||||
with open(tos_path, "w") as f:
|
||||
f.write("I have read, understood ad agree the Terms and Conditions.")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
print(" > You must agree to the terms of service to use this model.")
|
||||
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
|
||||
print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]')
|
||||
answer = input(" | | > ")
|
||||
if answer.lower() == "y":
|
||||
with open(tos_path, "w", encoding="utf-8") as f:
|
||||
f.write("I have read, understood and agreed to the Terms and Conditions.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def tos_agreed(self, model_item, model_full_path):
|
||||
@staticmethod
|
||||
def tos_agreed(model_item, model_full_path):
|
||||
"""Check if the user has agreed to the terms of service"""
|
||||
if "tos_required" in model_item and model_item["tos_required"]:
|
||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||
|
@ -392,7 +393,7 @@ class ModelManager(object):
|
|||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts_v1" in model_name:
|
||||
if "xtts" in model_name:
|
||||
try:
|
||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||
except:
|
||||
|
@ -406,7 +407,7 @@ class ModelManager(object):
|
|||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
if (
|
||||
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
|
||||
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
||||
): # TODO:This is stupid but don't care for now.
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
# update paths in the config.json
|
||||
|
@ -526,12 +527,12 @@ class ModelManager(object):
|
|||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
if progress_bar:
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
|
||||
with open(temp_zip_name, "wb") as file:
|
||||
for data in r.iter_content(block_size):
|
||||
if progress_bar:
|
||||
progress_bar.update(len(data))
|
||||
ModelManager.tqdm_progress.update(len(data))
|
||||
file.write(data)
|
||||
with zipfile.ZipFile(temp_zip_name) as z:
|
||||
z.extractall(output_folder)
|
||||
|
@ -561,12 +562,12 @@ class ModelManager(object):
|
|||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
if progress_bar:
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
|
||||
with open(temp_tar_name, "wb") as file:
|
||||
for data in r.iter_content(block_size):
|
||||
if progress_bar:
|
||||
progress_bar.update(len(data))
|
||||
ModelManager.tqdm_progress.update(len(data))
|
||||
file.write(data)
|
||||
with tarfile.open(temp_tar_name) as t:
|
||||
t.extractall(output_folder)
|
||||
|
@ -597,10 +598,10 @@ class ModelManager(object):
|
|||
block_size = 1024 # 1 Kibibyte
|
||||
with open(temp_zip_name, "wb") as file:
|
||||
if progress_bar:
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
for data in r.iter_content(block_size):
|
||||
if progress_bar:
|
||||
progress_bar.update(len(data))
|
||||
ModelManager.tqdm_progress.update(len(data))
|
||||
file.write(data)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue