mirror of https://github.com/coqui-ai/TTS.git
Find model files
This commit is contained in:
parent
bfee55af2b
commit
160de0222d
|
@ -3,6 +3,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
@ -139,8 +140,32 @@ class ModelManager(object):
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
self._update_paths(output_path, output_config_path)
|
self._update_paths(output_path, output_config_path)
|
||||||
|
# find downloaded files
|
||||||
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
return output_model_path, output_config_path, model_item
|
return output_model_path, output_config_path, model_item
|
||||||
|
|
||||||
|
def _find_files(self, output_path:str) -> Tuple[str, str]:
|
||||||
|
"""Find the model and config files in the output path
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path (str): path to the model files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, str]: path to the model file and config file
|
||||||
|
"""
|
||||||
|
model_file = None
|
||||||
|
config_file = None
|
||||||
|
for file_name in os.listdir(output_path):
|
||||||
|
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
|
||||||
|
model_file = os.path.join(output_path, file_name)
|
||||||
|
elif file_name == "config.json":
|
||||||
|
config_file = os.path.join(output_path, file_name)
|
||||||
|
if model_file is None:
|
||||||
|
raise ValueError(" [!] Model file not found in the output path")
|
||||||
|
if config_file is None:
|
||||||
|
raise ValueError(" [!] Config file not found in the output path")
|
||||||
|
return model_file, config_file
|
||||||
|
|
||||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
def _update_paths(self, output_path: str, config_path: str) -> None:
|
||||||
"""Update paths for certain files in config.json after download.
|
"""Update paths for certain files in config.json after download.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue