mirror of https://github.com/coqui-ai/TTS.git
Made the tqdm `progress_bar` objects of static download methods a static class variable
This commit is contained in:
parent
6fef4f9067
commit
c3a746fee5
|
@ -26,7 +26,9 @@ LICENSE_URLS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
|
tqdm_progress = None
|
||||||
"""Manage TTS models defined in .models.json.
|
"""Manage TTS models defined in .models.json.
|
||||||
It provides an interface to list and download
|
It provides an interface to list and download
|
||||||
models defines in '.model.json'
|
models defines in '.model.json'
|
||||||
|
@ -109,7 +111,6 @@ class ModelManager(object):
|
||||||
def _list_for_model_type(self, model_type):
|
def _list_for_model_type(self, model_type):
|
||||||
models_name_list = []
|
models_name_list = []
|
||||||
model_count = 1
|
model_count = 1
|
||||||
model_type = "tts_models"
|
|
||||||
models_name_list.extend(self._list_models(model_type, model_count))
|
models_name_list.extend(self._list_models(model_type, model_count))
|
||||||
return models_name_list
|
return models_name_list
|
||||||
|
|
||||||
|
@ -298,22 +299,22 @@ class ModelManager(object):
|
||||||
model_item = self.set_model_url(model_item)
|
model_item = self.set_model_url(model_item)
|
||||||
return model_item, model_full_name, model, md5hash
|
return model_item, model_full_name, model, md5hash
|
||||||
|
|
||||||
def ask_tos(self, model_full_path):
|
@staticmethod
|
||||||
|
def ask_tos(model_full_path):
|
||||||
"""Ask the user to agree to the terms of service"""
|
"""Ask the user to agree to the terms of service"""
|
||||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||||
if not os.path.exists(tos_path):
|
|
||||||
print(" > You must agree to the terms of service to use this model.")
|
print(" > You must agree to the terms of service to use this model.")
|
||||||
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
|
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
|
||||||
print(' | > "I have read, understood and agreed the Terms and Conditions." - [y/n]')
|
print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]')
|
||||||
answer = input(" | | > ")
|
answer = input(" | | > ")
|
||||||
if answer.lower() == "y":
|
if answer.lower() == "y":
|
||||||
with open(tos_path, "w") as f:
|
with open(tos_path, "w", encoding="utf-8") as f:
|
||||||
f.write("I have read, understood ad agree the Terms and Conditions.")
|
f.write("I have read, understood and agreed to the Terms and Conditions.")
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def tos_agreed(self, model_item, model_full_path):
|
@staticmethod
|
||||||
|
def tos_agreed(model_item, model_full_path):
|
||||||
"""Check if the user has agreed to the terms of service"""
|
"""Check if the user has agreed to the terms of service"""
|
||||||
if "tos_required" in model_item and model_item["tos_required"]:
|
if "tos_required" in model_item and model_item["tos_required"]:
|
||||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||||
|
@ -392,7 +393,7 @@ class ModelManager(object):
|
||||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||||
# if the configs are different, redownload it
|
# if the configs are different, redownload it
|
||||||
# ToDo: we need a better way to handle it
|
# ToDo: we need a better way to handle it
|
||||||
if "xtts_v1" in model_name:
|
if "xtts" in model_name:
|
||||||
try:
|
try:
|
||||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||||
except:
|
except:
|
||||||
|
@ -406,7 +407,7 @@ class ModelManager(object):
|
||||||
output_model_path = output_path
|
output_model_path = output_path
|
||||||
output_config_path = None
|
output_config_path = None
|
||||||
if (
|
if (
|
||||||
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
|
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
||||||
): # TODO:This is stupid but don't care for now.
|
): # TODO:This is stupid but don't care for now.
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
|
@ -526,12 +527,12 @@ class ModelManager(object):
|
||||||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
|
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
|
||||||
with open(temp_zip_name, "wb") as file:
|
with open(temp_zip_name, "wb") as file:
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar.update(len(data))
|
ModelManager.tqdm_progress.update(len(data))
|
||||||
file.write(data)
|
file.write(data)
|
||||||
with zipfile.ZipFile(temp_zip_name) as z:
|
with zipfile.ZipFile(temp_zip_name) as z:
|
||||||
z.extractall(output_folder)
|
z.extractall(output_folder)
|
||||||
|
@ -561,12 +562,12 @@ class ModelManager(object):
|
||||||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
|
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
|
||||||
with open(temp_tar_name, "wb") as file:
|
with open(temp_tar_name, "wb") as file:
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar.update(len(data))
|
ModelManager.tqdm_progress.update(len(data))
|
||||||
file.write(data)
|
file.write(data)
|
||||||
with tarfile.open(temp_tar_name) as t:
|
with tarfile.open(temp_tar_name) as t:
|
||||||
t.extractall(output_folder)
|
t.extractall(output_folder)
|
||||||
|
@ -597,10 +598,10 @@ class ModelManager(object):
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
with open(temp_zip_name, "wb") as file:
|
with open(temp_zip_name, "wb") as file:
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
progress_bar.update(len(data))
|
ModelManager.tqdm_progress.update(len(data))
|
||||||
file.write(data)
|
file.write(data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Reference in New Issue