Redownload XTTS when the local and remote configs do not match

This commit is contained in:
Edresson Casanova 2023-10-06 17:16:30 -03:00
parent 0520697b5f
commit 4a6103fec9
1 changed file with 47 additions and 21 deletions

View File

@ -6,6 +6,7 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
@ -320,26 +321,10 @@ class ModelManager(object):
return False return False
return True return True
def download_model(self, model_name): def check_if_files_size(self, model_name):
"""Download model files given the full model name. pass
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files: def create_dir_and_download_model(self, model_name, model_item, output_path):
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.")
else:
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
# handle TOS # handle TOS
if not self.tos_agreed(model_item, output_path): if not self.tos_agreed(model_item, output_path):
@ -360,6 +345,47 @@ class ModelManager(object):
rmtree(output_path) rmtree(output_path)
raise e raise e
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
# find downloaded files # find downloaded files
output_model_path = output_path output_model_path = output_path
output_config_path = None output_config_path = None