Find model files

This commit is contained in:
Eren Gölge 2022-03-22 17:47:34 +01:00
parent bfee55af2b
commit 160de0222d
1 changed files with 25 additions and 0 deletions

View File

@ -3,6 +3,7 @@ import json
import os
import zipfile
from pathlib import Path
from typing import Tuple
from shutil import copyfile, rmtree
import requests
@ -139,8 +140,32 @@ class ModelManager(object):
self._download_zip_file(model_item["github_rls_url"], output_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
# find downloaded files
output_model_path, output_config_path = self._find_files(output_path)
return output_model_path, output_config_path, model_item
def _find_files(self, output_path:str) -> Tuple[str, str]:
"""Find the model and config files in the output path
Args:
output_path (str): path to the model files
Returns:
Tuple[str, str]: path to the model file and config file
"""
model_file = None
config_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)
if model_file is None:
raise ValueError(" [!] Model file not found in the output path")
if config_file is None:
raise ValueError(" [!] Config file not found in the output path")
return model_file, config_file
def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download.