mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'coqui-ai:dev' into dev
This commit is contained in:
commit
aa7433fb5a
|
@ -1 +1 @@
|
||||||
0.17.7
|
0.17.8
|
||||||
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import fsspec
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@ -320,26 +321,7 @@ class ModelManager(object):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def download_model(self, model_name):
|
def create_dir_and_download_model(self, model_name, model_item, output_path):
|
||||||
"""Download model files given the full model name.
|
|
||||||
Model name is in the format
|
|
||||||
'type/language/dataset/model'
|
|
||||||
e.g. 'tts_model/en/ljspeech/tacotron'
|
|
||||||
|
|
||||||
Every model must have the following files:
|
|
||||||
- *.pth : pytorch model checkpoint file.
|
|
||||||
- config.json : model config file.
|
|
||||||
- scale_stats.npy (if exist): scale values for preprocessing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): model name as explained above.
|
|
||||||
"""
|
|
||||||
model_item, model_full_name, model = self._set_model_item(model_name)
|
|
||||||
# set the model specific output path
|
|
||||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
|
||||||
if os.path.exists(output_path):
|
|
||||||
print(f" > {model_name} is already downloaded.")
|
|
||||||
else:
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
# handle TOS
|
# handle TOS
|
||||||
if not self.tos_agreed(model_item, output_path):
|
if not self.tos_agreed(model_item, output_path):
|
||||||
|
@ -360,6 +342,55 @@ class ModelManager(object):
|
||||||
rmtree(output_path)
|
rmtree(output_path)
|
||||||
raise e
|
raise e
|
||||||
self.print_model_license(model_item=model_item)
|
self.print_model_license(model_item=model_item)
|
||||||
|
|
||||||
|
def check_if_configs_are_equal(self, model_name, model_item, output_path):
|
||||||
|
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
|
||||||
|
config_local = json.load(f)
|
||||||
|
remote_url = None
|
||||||
|
for url in model_item["hf_url"]:
|
||||||
|
if "config.json" in url:
|
||||||
|
remote_url = url
|
||||||
|
break
|
||||||
|
|
||||||
|
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
|
||||||
|
config_remote = json.load(f)
|
||||||
|
|
||||||
|
if not config_local == config_remote:
|
||||||
|
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
|
||||||
|
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||||
|
else:
|
||||||
|
print(f" > {model_name} is already downloaded.")
|
||||||
|
|
||||||
|
def download_model(self, model_name):
|
||||||
|
"""Download model files given the full model name.
|
||||||
|
Model name is in the format
|
||||||
|
'type/language/dataset/model'
|
||||||
|
e.g. 'tts_model/en/ljspeech/tacotron'
|
||||||
|
|
||||||
|
Every model must have the following files:
|
||||||
|
- *.pth : pytorch model checkpoint file.
|
||||||
|
- config.json : model config file.
|
||||||
|
- scale_stats.npy (if exist): scale values for preprocessing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): model name as explained above.
|
||||||
|
"""
|
||||||
|
model_item, model_full_name, model = self._set_model_item(model_name)
|
||||||
|
# set the model specific output path
|
||||||
|
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||||
|
if os.path.exists(output_path):
|
||||||
|
# if the configs are different, redownload it
|
||||||
|
# ToDo: we need a better way to handle it
|
||||||
|
if "xtts_v1" in model_name:
|
||||||
|
try:
|
||||||
|
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print(f" > {model_name} is already downloaded.")
|
||||||
|
else:
|
||||||
|
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||||
|
|
||||||
# find downloaded files
|
# find downloaded files
|
||||||
output_model_path = output_path
|
output_model_path = output_path
|
||||||
output_config_path = None
|
output_config_path = None
|
||||||
|
|
Loading…
Reference in New Issue