Add model hash check

This commit is contained in:
WeberJulian 2023-10-20 09:25:51 -03:00
parent 5e7cba22bd
commit 478fe0b28d
2 changed files with 18 additions and 5 deletions

View File

@ -20,8 +20,10 @@
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json"
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
],
"model_hash": "10163afc541dc86801b33d1f3217b456",
"default_vocoder": null,
"commit": "82910a63",
"license": "CPML",

View File

@ -294,8 +294,9 @@ class ModelManager(object):
# get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model
return model_item, model_full_name, model, md5hash
def ask_tos(self, model_full_path):
"""Ask the user to agree to the terms of service"""
@ -358,8 +359,6 @@ class ModelManager(object):
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name):
"""Download model files given the full model name.
@ -375,10 +374,22 @@ class ModelManager(object):
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
if md5sum is not None:
md5sum_file = os.path.join(output_path, "hash.md5")
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name: