use native amp for tacotron training

This commit is contained in:
erogol 2020-11-14 12:59:28 +01:00
parent 6cc464ead6
commit d8511efa8f
1 changed file with 7 additions and 3 deletions

View File

@ -539,12 +539,16 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
print(" > Restoring Model.")
model.load_state_dict(checkpoint['model'])
# optimizer restore
# optimizer.load_state_dict(checkpoint['optimizer'])
print(" > Restoring Optimizer.")
optimizer.load_state_dict(checkpoint['optimizer'])
if "scaler" in checkpoint and c.mixed_precision:
print(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except KeyError:
print(" > Partial model initialization.")
model_dict = model.state_dict()