Revert "XTTS redownload if needed"

This commit is contained in:
Gorkem 2023-10-07 01:04:10 +03:00 committed by GitHub
parent 0f46757c47
commit b3fecdcfda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 51 deletions

View File

@ -6,7 +6,6 @@ from pathlib import Path
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple
import fsspec
import requests
from tqdm import tqdm
@ -321,7 +320,26 @@ class ModelManager(object):
return False
return True
def create_dir_and_download_model(self, model_name, model_item, output_path):
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.")
else:
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
@ -342,55 +360,6 @@ class ModelManager(object):
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path):
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
pass
else:
print(f" > {model_name} is already downloaded.")
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
# find downloaded files
output_model_path = output_path
output_config_path = None