mirror of https://github.com/coqui-ai/TTS.git
Download HF models
This commit is contained in:
parent
f59da4dba5
commit
5a31fad502
|
@ -245,6 +245,26 @@ class ModelManager(object):
|
||||||
else:
|
else:
|
||||||
print(" > Model's license - No license information available")
|
print(" > Model's license - No license information available")
|
||||||
|
|
||||||
|
def _download_github_model(self, model_item: Dict, output_path: str):
|
||||||
|
if isinstance(model_item["github_rls_url"], list):
|
||||||
|
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
|
else:
|
||||||
|
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
|
|
||||||
|
def _download_hf_model(self, model_item:Dict, output_path: str):
|
||||||
|
if isinstance(model_item["hf_url"], list):
|
||||||
|
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
||||||
|
else:
|
||||||
|
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
||||||
|
|
||||||
|
def set_model_url(self, model_item: Dict):
|
||||||
|
model_item["model_url"] = None
|
||||||
|
if "github_rls_url" in model_item:
|
||||||
|
model_item["model_url"] = model_item["github_rls_url"]
|
||||||
|
elif "hf_url" in model_item:
|
||||||
|
model_item["model_url"] = model_item["hf_url"]
|
||||||
|
return model_item
|
||||||
|
|
||||||
def download_model(self, model_name):
|
def download_model(self, model_name):
|
||||||
"""Download model files given the full model name.
|
"""Download model files given the full model name.
|
||||||
Model name is in the format
|
Model name is in the format
|
||||||
|
@ -264,6 +284,7 @@ class ModelManager(object):
|
||||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||||
model_item["model_type"] = model_type
|
model_item["model_type"] = model_type
|
||||||
|
model_item = self.set_model_url(model_item)
|
||||||
# set the model specific output path
|
# set the model specific output path
|
||||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||||
if os.path.exists(output_path):
|
if os.path.exists(output_path):
|
||||||
|
@ -271,16 +292,16 @@ class ModelManager(object):
|
||||||
else:
|
else:
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
print(f" > Downloading model to {output_path}")
|
print(f" > Downloading model to {output_path}")
|
||||||
# download from github release
|
if "github_rls_url" in model_item:
|
||||||
if isinstance(model_item["github_rls_url"], list):
|
self._download_github_model(model_item, output_path)
|
||||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
elif "hf_url" in model_item:
|
||||||
else:
|
self._download_hf_model(model_item, output_path)
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
|
||||||
self.print_model_license(model_item=model_item)
|
self.print_model_license(model_item=model_item)
|
||||||
# find downloaded files
|
# find downloaded files
|
||||||
output_model_path = output_path
|
output_model_path = output_path
|
||||||
output_config_path = None
|
output_config_path = None
|
||||||
if model != "tortoise-v2":
|
if model not in ["tortoise-v2", "bark"]: # TODO:This is stupid but don't care for now.
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
self._update_paths(output_path, output_config_path)
|
self._update_paths(output_path, output_config_path)
|
||||||
|
|
Loading…
Reference in New Issue