fix model manager

This commit is contained in:
Eren Gölge 2021-03-10 17:01:19 +01:00
parent d260fb03a2
commit e5bb317242
1 changed files with 2 additions and 3 deletions

View File

@ -104,7 +104,7 @@ class ModelManager(object):
else: else:
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
output_stats_path = None output_stats_path = os.path.join(output_path, 'scale_stats.npy')
# download files to the output path # download files to the output path
if self._check_dict_key(model_item, 'github_rls_url'): if self._check_dict_key(model_item, 'github_rls_url'):
# download from github release # download from github release
@ -118,8 +118,7 @@ class ModelManager(object):
self._download_gdrive_file(model_item['stats_file'], output_stats_path) self._download_gdrive_file(model_item['stats_file'], output_stats_path)
# set the scale_path.npy file path in the model config.json # set the scale_path.npy file path in the model config.json
if self._check_dict_key(model_item, 'stats_file') or os.path.exists(os.path.join(output_path, 'scale_stats.npy')): if self._check_dict_key(model_item, 'stats_file') or os.path.exists(output_stats_path):
output_stats_path = os.path.join(output_path, 'scale_stats.npy')
# set scale stats path in config.json # set scale stats path in config.json
config_path = output_config_path config_path = output_config_path
config = load_config(config_path) config = load_config(config_path)