diff --git a/train.py b/train.py index 831a106f..ac61cc9f 100644 --- a/train.py +++ b/train.py @@ -218,8 +218,7 @@ def main(args): if c.checkpoint: # save model save_checkpoint(model, optimizer, linear_loss.data[0], - best_loss, OUT_PATH, - current_step, epoch) + OUT_PATH, current_step, epoch) # Diagnostic visualizations const_spec = linear_output[0].data.cpu().numpy() diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 9221d611..d5b595f7 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -60,7 +60,7 @@ def _trim_model_state_dict(state_dict): return new_state_dict -def save_checkpoint(model, optimizer, model_loss, best_loss, out_path, +def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch): checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) checkpoint_path = os.path.join(out_path, checkpoint_path)