mirror of https://github.com/coqui-ai/TTS.git
Revert "XTTS redownload if needed"
parent 0f46757c47
commit b3fecdcfda
@@ -6,7 +6,6 @@ from pathlib import Path
 from shutil import copyfile, rmtree
 from typing import Dict, List, Tuple
 
-import fsspec
 import requests
 from tqdm import tqdm
 
@@ -321,7 +320,26 @@ class ModelManager(object):
             return False
         return True
 
-    def create_dir_and_download_model(self, model_name, model_item, output_path):
+    def download_model(self, model_name):
+        """Download model files given the full model name.
+        Model name is in the format
+            'type/language/dataset/model'
+            e.g. 'tts_model/en/ljspeech/tacotron'
+
+        Every model must have the following files:
+            - *.pth : pytorch model checkpoint file.
+            - config.json : model config file.
+            - scale_stats.npy (if exist): scale values for preprocessing.
+
+        Args:
+            model_name (str): model name as explained above.
+        """
+        model_item, model_full_name, model = self._set_model_item(model_name)
+        # set the model specific output path
+        output_path = os.path.join(self.output_prefix, model_full_name)
+        if os.path.exists(output_path):
+            print(f" > {model_name} is already downloaded.")
+        else:
             os.makedirs(output_path, exist_ok=True)
             # handle TOS
             if not self.tos_agreed(model_item, output_path):
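The hunk above restores the original download_model of the ModelManager class (defined in TTS/utils/manage.py), which simply skips any model whose output directory already exists. A minimal usage sketch, not part of the diff: the model name below is only an example of the 'type/language/dataset/model' format from the docstring, and the first call performs a real download.

from TTS.utils.manage import ModelManager

# With no arguments, ModelManager falls back to the model list shipped with the package.
manager = ModelManager()

# In this version of the class the call returns the checkpoint path, the config
# path (if the model ships one), and the model's catalogue entry.
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/tacotron2-DDC")
print(model_path, config_path)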
@@ -342,55 +360,6 @@ class ModelManager(object):
                 rmtree(output_path)
                 raise e
             self.print_model_license(model_item=model_item)
-
-    def check_if_configs_are_equal(self, model_name, model_item, output_path):
-        with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
-            config_local = json.load(f)
-        remote_url = None
-        for url in model_item["hf_url"]:
-            if "config.json" in url:
-                remote_url = url
-                break
-
-        with fsspec.open(remote_url, "r", encoding="utf-8") as f:
-            config_remote = json.load(f)
-
-        if not config_local == config_remote:
-            print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
-            self.create_dir_and_download_model(model_name, model_item, output_path)
-        else:
-            print(f" > {model_name} is already downloaded.")
-
-    def download_model(self, model_name):
-        """Download model files given the full model name.
-        Model name is in the format
-            'type/language/dataset/model'
-            e.g. 'tts_model/en/ljspeech/tacotron'
-
-        Every model must have the following files:
-            - *.pth : pytorch model checkpoint file.
-            - config.json : model config file.
-            - scale_stats.npy (if exist): scale values for preprocessing.
-
-        Args:
-            model_name (str): model name as explained above.
-        """
-        model_item, model_full_name, model = self._set_model_item(model_name)
-        # set the model specific output path
-        output_path = os.path.join(self.output_prefix, model_full_name)
-        if os.path.exists(output_path):
-            # if the configs are different, redownload it
-            # ToDo: we need a better way to handle it
-            if "xtts_v1" in model_name:
-                try:
-                    self.check_if_configs_are_equal(model_name, model_item, output_path)
-                except:
-                    pass
-            else:
-                print(f" > {model_name} is already downloaded.")
-        else:
-            self.create_dir_and_download_model(model_name, model_item, output_path)
-
         # find downloaded files
         output_model_path = output_path
         output_config_path = None
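For reference, the behaviour this commit reverts is the config comparison removed above: for model names containing "xtts_v1", download_model looked up the remote config.json in the model's hf_url list, compared it with the local copy, and redownloaded on any mismatch, with the whole check wrapped in a bare except/pass. Below is a standalone sketch of that comparison; the function name and arguments are hypothetical, and the real logic lived in ModelManager.check_if_configs_are_equal.

import json

import fsspec


def needs_redownload(local_config_path, hf_urls):
    # Find the remote config.json among the model's Hugging Face URLs.
    remote_url = next((url for url in hf_urls if "config.json" in url), None)
    if remote_url is None:
        # Unlike the removed code, which assumed a config URL exists, bail out quietly.
        return False
    # fsspec.open reads a local path and an http(s) URL through the same interface,
    # which is why the reverted code imported fsspec in the first place.
    with fsspec.open(local_config_path, "r", encoding="utf-8") as f:
        config_local = json.load(f)
    with fsspec.open(remote_url, "r", encoding="utf-8") as f:
        config_remote = json.load(f)
    # Any difference between the local and remote config means the cached model is stale.
    return config_local != config_remote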