From 0ac628465574ab412c81d3fed773fab8c7f1b3eb Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Sat, 28 Apr 2018 13:27:56 -0700 Subject: [PATCH] resolve optimizer restore problem --- train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index aa8e92ef..d98fc26d 100644 --- a/train.py +++ b/train.py @@ -351,8 +351,6 @@ def main(args): c.num_mels, c.r) - optimizer = optim.Adam(model.parameters(), lr=c.lr) - if use_cuda: criterion = L1LossMasked().cuda() else: @@ -361,7 +359,12 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) + optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer.load_state_dict(checkpoint['optimizer']) + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.cuda() print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] @@ -369,6 +372,7 @@ def main(args): args.restore_step = checkpoint['step'] else: args.restore_step = 0 + optimizer = optim.Adam(model.parameters(), lr=c.lr) print(" > Starting a new training") if use_cuda: