mirror of https://github.com/coqui-ai/TTS.git
resolve optimizer restore problem
This commit is contained in:
parent
5f7bf85b19
commit
0ac6284655
8
train.py
8
train.py
|
@ -351,8 +351,6 @@ def main(args):
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.r)
|
c.r)
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
criterion = L1LossMasked().cuda()
|
criterion = L1LossMasked().cuda()
|
||||||
else:
|
else:
|
||||||
|
@ -361,7 +359,12 @@ def main(args):
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
for state in optimizer.state.values():
|
||||||
|
for k, v in state.items():
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
state[k] = v.cuda()
|
||||||
print(" > Model restored from step %d" % checkpoint['step'])
|
print(" > Model restored from step %d" % checkpoint['step'])
|
||||||
start_epoch = checkpoint['step'] // len(train_loader)
|
start_epoch = checkpoint['step'] // len(train_loader)
|
||||||
best_loss = checkpoint['linear_loss']
|
best_loss = checkpoint['linear_loss']
|
||||||
|
@ -369,6 +372,7 @@ def main(args):
|
||||||
args.restore_step = checkpoint['step']
|
args.restore_step = checkpoint['step']
|
||||||
else:
|
else:
|
||||||
args.restore_step = 0
|
args.restore_step = 0
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
print(" > Starting a new training")
|
print(" > Starting a new training")
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
|
|
Loading…
Reference in New Issue