fix handling scale_stats.npy for models downloaded from Github rls

This commit is contained in:
Eren Gölge 2021-03-10 16:40:30 +01:00
parent 4aba4e5b1e
commit d260fb03a2
1 changed files with 5 additions and 2 deletions

View File

@ -114,9 +114,12 @@ class ModelManager(object):
# download from gdrive
self._download_gdrive_file(model_item['model_file'], output_model_path)
self._download_gdrive_file(model_item['config_file'], output_config_path)
if self._check_dict_key(model_item, 'stats_file'):
if self._check_dict_key(model_item, 'stats_file'):
self._download_gdrive_file(model_item['stats_file'], output_stats_path)
# set the scale_path.npy file path in the model config.json
if self._check_dict_key(model_item, 'stats_file') or os.path.exists(os.path.join(output_path, 'scale_stats.npy')):
output_stats_path = os.path.join(output_path, 'scale_stats.npy')
self._download_gdrive_file(model_item['stats_file'], output_stats_path)
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)