some more tests for model saving

This commit is contained in:
Eren Golge 2018-02-21 07:21:44 -08:00
parent 6efb761139
commit d510f0e8aa
2 changed files with 2 additions and 3 deletions

View File

@ -218,8 +218,7 @@ def main(args):
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, linear_loss.data[0],
best_loss, OUT_PATH,
current_step, epoch)
OUT_PATH, current_step, epoch)
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()

View File

@ -60,7 +60,7 @@ def _trim_model_state_dict(state_dict):
return new_state_dict
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
def save_checkpoint(model, optimizer, model_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)