Add model hash check

This commit is contained in:
WeberJulian 2023-10-20 09:25:51 -03:00
parent 5e7cba22bd
commit 478fe0b28d
2 changed files with 18 additions and 5 deletions

View File

@ -20,8 +20,10 @@
"hf_url": [ "hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json" "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
], ],
"model_hash": "10163afc541dc86801b33d1f3217b456",
"default_vocoder": null, "default_vocoder": null,
"commit": "82910a63", "commit": "82910a63",
"license": "CPML", "license": "CPML",

View File

@ -294,8 +294,9 @@ class ModelManager(object):
# get model from models.json # get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model] model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type model_item["model_type"] = model_type
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item) model_item = self.set_model_url(model_item)
return model_item, model_full_name, model return model_item, model_full_name, model, md5hash
def ask_tos(self, model_full_path): def ask_tos(self, model_full_path):
"""Ask the user to agree to the terms of service""" """Ask the user to agree to the terms of service"""
@ -358,8 +359,6 @@ class ModelManager(object):
if not config_local == config_remote: if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name): def download_model(self, model_name):
"""Download model files given the full model name. """Download model files given the full model name.
@ -375,10 +374,22 @@ class ModelManager(object):
Args: Args:
model_name (str): model name as explained above. model_name (str): model name as explained above.
""" """
model_item, model_full_name, model = self._set_model_item(model_name) model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
# set the model specific output path # set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name) output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path): if os.path.exists(output_path):
if md5sum is not None:
md5sum_file = os.path.join(output_path, "hash.md5")
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it # if the configs are different, redownload it
# ToDo: we need a better way to handle it # ToDo: we need a better way to handle it
if "xtts_v1" in model_name: if "xtts_v1" in model_name: