diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 8a64dbae..1fa956ff 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path, 'step': current_step, 'epoch': epoch, 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y") + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': model.decoder.r } torch.save(state, checkpoint_path) @@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, 'step': current_step, 'epoch': epoch, 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y") + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': model.decoder.r } best_loss = model_loss bestmodel_path = 'best_model.pth.tar'