From 5acd9e82bdf8c4e62eb11e3cd088e79b3ddd2ef8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 16 Aug 2019 13:11:51 +0200 Subject: [PATCH] save model r value for checkpoints --- utils/generic_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 8a64dbae..1fa956ff 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path, 'step': current_step, 'epoch': epoch, 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y") + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': model.decoder.r } torch.save(state, checkpoint_path) @@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, 'step': current_step, 'epoch': epoch, 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y") + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': model.decoder.r } best_loss = model_loss bestmodel_path = 'best_model.pth.tar'