diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index b5c698f3..955eeb9b 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -6,6 +6,7 @@ from pathlib import Path from shutil import copyfile, rmtree from typing import Dict, List, Tuple +import fsspec import requests from tqdm import tqdm @@ -320,6 +321,46 @@ class ModelManager(object): return False return True + def create_dir_and_download_model(self, model_name, model_item, output_path): + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] You must agree to the terms of service to use this model.") + print(f" > Downloading model to {output_path}") + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + + except requests.RequestException as e: + print(f" > Failed to download the model file to {output_path}") + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) + + def check_if_configs_are_equal(self, model_name, model_item, output_path): + with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: + config_local = json.load(f) + remote_url = None + for url in model_item["hf_url"]: + if "config.json" in url: + remote_url = url + break + + with fsspec.open(remote_url, "r", encoding="utf-8") as f: + config_remote = json.load(f) + + if not config_local == config_remote: + print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + print(f" > {model_name} is already downloaded.") + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -338,28 +379,18 @@ class ModelManager(object): # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): - print(f" > {model_name} is already downloaded.") + # if the configs are different, redownload it + # ToDo: we need a better way to handle it + if "xtts_v1" in model_name: + try: + self.check_if_configs_are_equal(model_name, model_item, output_path) + except: + pass + else: + print(f" > {model_name} is already downloaded.") else: - os.makedirs(output_path, exist_ok=True) - # handle TOS - if not self.tos_agreed(model_item, output_path): - if not self.ask_tos(output_path): - os.rmdir(output_path) - raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") - try: - if "fairseq" in model_name: - self.download_fairseq_model(model_name, output_path) - elif "github_rls_url" in model_item: - self._download_github_model(model_item, output_path) - elif "hf_url" in model_item: - self._download_hf_model(model_item, output_path) + self.create_dir_and_download_model(model_name, model_item, output_path) - except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") - rmtree(output_path) - raise e - self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None