Let download model files

This commit is contained in:
Eren Gölge 2023-05-01 16:35:00 +02:00
parent e9fb28f19d
commit 8c739fd5f2
2 changed files with 19 additions and 35 deletions

View File

@ -220,15 +220,20 @@
"license": "apache 2.0", "license": "apache 2.0",
"contact": "adamfroghyar@gmail.com" "contact": "adamfroghyar@gmail.com"
} }
}, },
"multi-dataset":{ "multi-dataset":{
"tortoise-v2":{ "tortoise-v2":{
"description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts",
"github_rls_url": ["https://coqui.gateway.scarf.sh/v0.14.0_models/tortoise_models.zip.part-aa", "github_rls_url": ["https://coqui.gateway.scarf.sh/v0.14.0_models/autoregressive.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/tortoise_models.zip.part-ab", "https://coqui.gateway.scarf.sh/v0.14.0_models/clvp2.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/tortoise_models.zip.part-ac", "https://coqui.gateway.scarf.sh/v0.14.0_models/cvvp.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/tortoise_models.zip.part-ad"], "https://coqui.gateway.scarf.sh/v0.14.0_models/diffusion_decoder.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/rlg_auto.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/rlg_diffuser.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/vocoder.pth",
"https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
],
"commit": "c1875f6", "commit": "c1875f6",
"default_vocoder": null, "default_vocoder": null,
"author": "neonbjb James Betker", "author": "neonbjb James Betker",

View File

@ -273,7 +273,7 @@ class ModelManager(object):
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
# download from github release # download from github release
if isinstance(model_item["github_rls_url"], list): if isinstance(model_item["github_rls_url"], list):
self._download_parted_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else: else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
@ -423,7 +423,7 @@ class ModelManager(object):
rmtree(os.path.join(output_folder, z.namelist()[0])) rmtree(os.path.join(output_folder, z.namelist()[0]))
@staticmethod @staticmethod
def _download_parted_zip_file(file_urls, output_folder, progress_bar): def _download_model_files(file_urls, output_folder, progress_bar):
"""Download the github releases""" """Download the github releases"""
for file_url in file_urls: for file_url in file_urls:
# download the file # download the file
@ -431,36 +431,15 @@ class ModelManager(object):
# extract the file # extract the file
bease_filename = file_url.split("/")[-1] bease_filename = file_url.split("/")[-1]
temp_zip_name = os.path.join(output_folder, bease_filename) temp_zip_name = os.path.join(output_folder, bease_filename)
file = open(temp_zip_name, "wb")
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: with open(temp_zip_name, "wb") as file:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) if progress_bar:
for data in r.iter_content(block_size): progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
if progress_bar: for data in r.iter_content(block_size):
progress_bar.update(len(data)) if progress_bar:
file.write(data) progress_bar.update(len(data))
print("> Extracting All Models") file.write(data)
zip_file_name = os.path.join(output_folder, "tmp.zip")
with open(zip_file_name, "wb") as zip_file:
for part in file_urls:
zip_name = os.path.join(output_folder, part.split("/")[-1])
with open(zip_name, "rb") as zip_part:
zip_file.write(zip_part.read())
# remove parts
for part in file_urls:
temp_zip_name = os.path.join(output_folder, part.split("/")[-1])
os.remove(temp_zip_name)
with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall(path=output_folder)
os.remove(zip_file_name)
# move the files to the outer path
print(zip_ref.namelist())
for file_path in zip_ref.namelist()[1:]:
src_path = os.path.join(output_folder, file_path)
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
@staticmethod @staticmethod
def _check_dict_key(my_dict, key): def _check_dict_key(my_dict, key):