enable backward compat for loading the best model

This commit is contained in:
Eren Gölge 2021-02-18 17:20:36 +00:00
parent d0454461de
commit 5b70c8ba4f
1 changed files with 20 additions and 10 deletions

View File

@ -86,24 +86,34 @@ def get_last_checkpoint(path):
last_models = {}
last_model_nums = {}
for key in ['checkpoint', 'best_model']:
last_model_num = 0
last_model_num = None
last_model = None
# pass all the checkpoint files and find
# the one with the largest model number suffix.
for file_name in file_names:
try:
model_num = int(re.search(
f"{key}_([0-9]+)", file_name).groups()[0])
if model_num > last_model_num:
match = re.search(f"{key}_([0-9]+)", file_name)
if match is not None:
model_num = int(match.groups()[0])
if model_num > last_model_num or last_model_num is None:
last_model_num = model_num
last_model = file_name
except AttributeError: # if there's no match in the filename
continue
last_models[key] = last_model
last_model_nums[key] = last_model_num
# if there is not checkpoint found above
# find the checkpoint with the latest
# modification date.
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = os.path.getctime(last_model)
if last_model is not None:
last_models[key] = last_model
last_model_nums[key] = last_model_num
# check what models were found
if not last_models:
raise ValueError(f"No models found in continue path {path}!")
elif 'checkpoint' not in last_models: # no checkpoint just best model
if 'checkpoint' not in last_models: # no checkpoint just best model
last_models['checkpoint'] = last_models['best_model']
elif 'best_model' not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case