diff --git a/TTS/.models.json b/TTS/.models.json index 841610d3..8e35893b 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -20,8 +20,10 @@ "hf_url": [ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json" + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5" ], + "model_hash": "10163afc541dc86801b33d1f3217b456", "default_vocoder": null, "commit": "82910a63", "license": "CPML", diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 1db6ae39..eef987ef 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -294,8 +294,9 @@ class ModelManager(object): # get model from models.json model_item = self.models_dict[model_type][lang][dataset][model] model_item["model_type"] = model_type + md5hash = model_item["model_hash"] if "model_hash" in model_item else None model_item = self.set_model_url(model_item) - return model_item, model_full_name, model + return model_item, model_full_name, model, md5hash def ask_tos(self, model_full_path): """Ask the user to agree to the terms of service""" @@ -358,8 +359,6 @@ class ModelManager(object): if not config_local == config_remote: print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") self.create_dir_and_download_model(model_name, model_item, output_path) - else: - print(f" > {model_name} is already downloaded.") def download_model(self, model_name): """Download model files given the full model name. @@ -375,10 +374,22 @@ class ModelManager(object): Args: model_name (str): model name as explained above. """ - model_item, model_full_name, model = self._set_model_item(model_name) + model_item, model_full_name, model, md5sum = self._set_model_item(model_name) # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): + if md5sum is not None: + md5sum_file = os.path.join(output_path, "hash.md5") + if os.path.isfile(md5sum_file): + with open(md5sum_file, mode="r") as f: + if not f.read() == md5sum: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + print(f" > {model_name} is already downloaded.") + else: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) # if the configs are different, redownload it # ToDo: we need a better way to handle it if "xtts_v1" in model_name: