save model r value for checkpoints

This commit is contained in:
Eren Golge 2019-08-16 13:11:51 +02:00
parent 446cd6fa06
commit 5acd9e82bd
1 changed files with 4 additions and 2 deletions

View File

@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
'step': current_step, 'step': current_step,
'epoch': epoch, 'epoch': epoch,
'linear_loss': model_loss, 'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y") 'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
} }
torch.save(state, checkpoint_path) torch.save(state, checkpoint_path)
@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
'step': current_step, 'step': current_step,
'epoch': epoch, 'epoch': epoch,
'linear_loss': model_loss, 'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y") 'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
} }
best_loss = model_loss best_loss = model_loss
bestmodel_path = 'best_model.pth.tar' bestmodel_path = 'best_model.pth.tar'