mirror of https://github.com/coqui-ai/TTS.git
Add model hash check
This commit is contained in:
parent
5e7cba22bd
commit
478fe0b28d
|
@ -20,8 +20,10 @@
|
|||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json"
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
|
||||
],
|
||||
"model_hash": "10163afc541dc86801b33d1f3217b456",
|
||||
"default_vocoder": null,
|
||||
"commit": "82910a63",
|
||||
"license": "CPML",
|
||||
|
|
|
@ -294,8 +294,9 @@ class ModelManager(object):
|
|||
# get model from models.json
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
||||
model_item = self.set_model_url(model_item)
|
||||
return model_item, model_full_name, model
|
||||
return model_item, model_full_name, model, md5hash
|
||||
|
||||
def ask_tos(self, model_full_path):
|
||||
"""Ask the user to agree to the terms of service"""
|
||||
|
@ -358,8 +359,6 @@ class ModelManager(object):
|
|||
if not config_local == config_remote:
|
||||
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
|
@ -375,10 +374,22 @@ class ModelManager(object):
|
|||
Args:
|
||||
model_name (str): model name as explained above.
|
||||
"""
|
||||
model_item, model_full_name, model = self._set_model_item(model_name)
|
||||
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
|
||||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
if md5sum is not None:
|
||||
md5sum_file = os.path.join(output_path, "hash.md5")
|
||||
if os.path.isfile(md5sum_file):
|
||||
with open(md5sum_file, mode="r") as f:
|
||||
if not f.read() == md5sum:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts_v1" in model_name:
|
||||
|
|
Loading…
Reference in New Issue