Merge pull request #3038 from coqui-ai/xtts_redonwload

XTTS redownload if needed
This commit is contained in:
Gorkem 2023-10-07 01:02:44 +03:00 committed by GitHub
commit 0f46757c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 51 additions and 20 deletions

View File

@ -6,6 +6,7 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
@ -320,6 +321,46 @@ class ModelManager(object):
return False return False
return True return True
def create_dir_and_download_model(self, model_name, model_item, output_path):
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path):
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name): def download_model(self, model_name):
"""Download model files given the full model name. """Download model files given the full model name.
Model name is in the format Model name is in the format
@ -338,28 +379,18 @@ class ModelManager(object):
# set the model specific output path # set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name) output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path): if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.") # if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
pass
else:
print(f" > {model_name} is already downloaded.")
else: else:
os.makedirs(output_path, exist_ok=True) self.create_dir_and_download_model(model_name, model_item, output_path)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
# find downloaded files # find downloaded files
output_model_path = output_path output_model_path = output_path
output_config_path = None output_config_path = None