Redownload XTTS with the local and remote config do not match

This commit is contained in:
Edresson Casanova 2023-10-06 17:16:30 -03:00
parent 0520697b5f
commit 4a6103fec9
1 changed files with 47 additions and 21 deletions

View File

@ -6,6 +6,7 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
@ -320,6 +321,31 @@ class ModelManager(object):
return False return False
return True return True
def check_if_files_size(self, model_name):
pass
def create_dir_and_download_model(self, model_name, model_item, output_path):
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
def download_model(self, model_name): def download_model(self, model_name):
"""Download model files given the full model name. """Download model files given the full model name.
Model name is in the format Model name is in the format
@ -338,28 +364,28 @@ class ModelManager(object):
# set the model specific output path # set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name) output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path): if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.") # if the configs are different, redownload it
else: # ToDo: we need a better way to handle it
os.makedirs(output_path, exist_ok=True) if "xtts_v1" in model_name:
# handle TOS with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
if not self.tos_agreed(model_item, output_path): config_local = json.load(f)
if not self.ask_tos(output_path): remote_url = None
os.rmdir(output_path) for url in model_item["hf_url"]:
raise Exception(" [!] You must agree to the terms of service to use this model.") if "config.json" in url:
print(f" > Downloading model to {output_path}") remote_url = url
try: break
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path) with fsspec.open(remote_url, "r", encoding="utf-8") as f:
elif "github_rls_url" in model_item: config_remote = json.load(f)
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item: if not config_local == config_remote:
self._download_hf_model(model_item, output_path) print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
# find downloaded files # find downloaded files
output_model_path = output_path output_model_path = output_path
output_config_path = None output_config_path = None