From 88bde77061ac093235d24cf0dfbef8133fc32358 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 14:09:28 +0200 Subject: [PATCH] fix checkpointing --- utils/io.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/utils/io.py b/utils/io.py index 9161d6fd..f6378336 100644 --- a/utils/io.py +++ b/utils/io.py @@ -47,9 +47,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False): return model, state -def save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, **kwargs): - checkpoint_path = os.path.join(output_folder, file_name) - +def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs): new_state_dict = model.state_dict() state = { 'model': new_state_dict, @@ -60,19 +58,21 @@ def save_model(model, optimizer, current_step, epoch, r, output_folder, file_nam 'r': model.decoder.r } state.update(kwargs) - torch.save(state, checkpoint_path) + torch.save(state, output_path) def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): - print(" > CHECKPOINT : {}".format(checkpoint_path)) file_name = 'checkpoint_{}.pth.tar'.format(current_step) - save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, **kwargs) + checkpoint_path = os.path.join(output_folder, file_name) + print(" > CHECKPOINT : {}".format(checkpoint_path)) + save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, **kwargs) def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs): if target_loss < best_loss: - print(" > BEST MODEL : {}".format(checkpoint_path)) file_name = 'best_model.pth.tar' - save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, model_loss=target_loss) + checkpoint_path = os.path.join(output_folder, file_name) + print(" > BEST MODEL : {}".format(checkpoint_path)) + save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, model_loss=target_loss) best_loss = target_loss return best_loss \ No newline at end of file