Merge branch 'coqui-ai:dev' into dev

This commit is contained in:
Jindrich Matousek 2023-10-08 10:08:15 +02:00 committed by GitHub
commit aa7433fb5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 21 deletions

View File

@ -1 +1 @@
0.17.7 0.17.8

View File

@ -6,6 +6,7 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
@ -320,26 +321,7 @@ class ModelManager(object):
return False return False
return True return True
def download_model(self, model_name): def create_dir_and_download_model(self, model_name, model_item, output_path):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.")
else:
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
# handle TOS # handle TOS
if not self.tos_agreed(model_item, output_path): if not self.tos_agreed(model_item, output_path):
@ -360,6 +342,55 @@ class ModelManager(object):
rmtree(output_path) rmtree(output_path)
raise e raise e
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path):
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
pass
else:
print(f" > {model_name} is already downloaded.")
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
# find downloaded files # find downloaded files
output_model_path = output_path output_model_path = output_path
output_config_path = None output_config_path = None