Find model files

This commit is contained in:
Eren Gölge 2022-03-22 17:47:34 +01:00
parent bfee55af2b
commit 160de0222d
1 changed files with 25 additions and 0 deletions

View File

@ -3,6 +3,7 @@ import json
import os import os
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import Tuple
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
import requests import requests
@ -139,8 +140,32 @@ class ModelManager(object):
self._download_zip_file(model_item["github_rls_url"], output_path) self._download_zip_file(model_item["github_rls_url"], output_path)
# update paths in the config.json # update paths in the config.json
self._update_paths(output_path, output_config_path) self._update_paths(output_path, output_config_path)
# find downloaded files
output_model_path, output_config_path = self._find_files(output_path)
return output_model_path, output_config_path, model_item return output_model_path, output_config_path, model_item
def _find_files(self, output_path:str) -> Tuple[str, str]:
"""Find the model and config files in the output path
Args:
output_path (str): path to the model files
Returns:
Tuple[str, str]: path to the model file and config file
"""
model_file = None
config_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)
if model_file is None:
raise ValueError(" [!] Model file not found in the output path")
if config_file is None:
raise ValueError(" [!] Config file not found in the output path")
return model_file, config_file
def _update_paths(self, output_path: str, config_path: str) -> None: def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download. """Update paths for certain files in config.json after download.