apex AMP O1 optimizations

repodiac 2020-06-23 13:42:41 +02:00 committed by erogol
parent 299f1cafe6
commit eb905aafd3
4 changed files with 46 additions and 9 deletions
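For context, the commit wires NVIDIA apex's AMP API into the trainer. Below is a minimal sketch of the O1 workflow these hunks build on, using a toy model; it assumes apex is installed and a CUDA device is available. The amp.scale_loss wrapper shown here is apex's documented pattern and does not itself appear in the hunks below.

# Minimal apex AMP O1 sketch with a toy model; not the project's exact code.
# Requires NVIDIA apex (https://github.com/NVIDIA/apex) and a CUDA device.
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# O1 patches torch/cuda functions so whitelisted ops run in FP16 while the
# parameters stay FP32, and installs dynamic loss scaling.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 10).cuda()
loss = model(x).pow(2).mean()

# Scale the loss so FP16 gradients do not underflow, then backprop.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Clip the gradients apex manages (the master params), not model.parameters().
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
optimizer.step()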


@@ -117,7 +117,7 @@ def format_data(data):

 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch):
+          ap, global_step, epoch, amp):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0))
     model.train()
@@ -172,7 +172,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # backward pass
         loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+        if amp:
+            amp_opt_params = amp.master_params(optimizer)
+        else:
+            amp_opt_params = None
+        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
         optimizer.step()

         # compute alignment error (the lower the better )
@@ -183,7 +187,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         if c.separate_stopnet:
             loss_dict['stopnet_loss'].backward()
             optimizer_st, _ = adam_weight_decay(optimizer_st)
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            if amp:
+                amp_opt_params = amp.master_params(optimizer)
+            else:
+                amp_opt_params = None
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
             optimizer_st.step()
         else:
             grad_norm_st = 0
@@ -245,7 +253,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 # save model
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
-                                model_loss=loss_dict['postnet_loss'])
+                                model_loss=loss_dict['postnet_loss'],
+                                amp_state_dict=amp.state_dict() if amp else None)

             # Diagnostic visualizations
             const_spec = postnet_output[0].data.cpu().numpy()
@@ -497,6 +506,14 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None

+    if c.apex_amp_level:
+        # pylint: disable=import-outside-toplevel
+        from apex import amp
+        model.cuda()
+        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
+    else:
+        amp = None
+
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
@@ -515,6 +532,10 @@ def main(args): # pylint: disable=redefined-outer-name
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model.load_state_dict(model_dict)
             del model_dict
+
+        if amp and 'amp' in checkpoint:
+            amp.load_state_dict(checkpoint['amp'])
+
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],
@@ -564,7 +585,7 @@ def main(args): # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)


 if __name__ == '__main__':
@@ -616,6 +637,9 @@ if __name__ == '__main__':
     check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))

+    if c.apex_amp_level:
+        print(" > apex AMP level: ", c.apex_amp_level)
+
     OUT_PATH = args.continue_path
     if args.continue_path == '':
         OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
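A small design note on the clipping blocks above: amp.master_params(optimizer) is, in apex's implementation, a generator over the optimizer's parameters, which is presumably why it is rebuilt right before each clipping call instead of being computed once per step. The same selection can be written more compactly (a stylistic sketch, not part of the commit; names as in the diff):

# Compact form of the parameter selection used above (sketch only).
amp_opt_params = amp.master_params(optimizer) if amp else None
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True,
                            amp_opt_params=amp_opt_params)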


@@ -67,6 +67,7 @@
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
     "loss_masking": true, // enable / disable loss masking against the sequence padding.
     "ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.
+    "apex_amp_level": "", // level of optimization with NVIDIA's apex for automatic mixed FP16/FP32 precision (AMP). NOTE: currently only "O1" is supported; use "" (empty string) to deactivate.

     // VALIDATION
     "run_eval": true,


@@ -6,6 +6,8 @@ import datetime
 def load_checkpoint(model, checkpoint_path, use_cuda=False):
     state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     model.load_state_dict(state['model'])
+    if amp and 'amp' in state:
+        amp.load_state_dict(state['amp'])
     if use_cuda:
         model.cuda()
     # set model stepsize
@@ -14,7 +16,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
     return model, state


-def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
     new_state_dict = model.state_dict()
     state = {
         'model': new_state_dict,
@@ -24,6 +26,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         'date': datetime.date.today().strftime("%B %d, %Y"),
         'r': r
     }
+    if amp_state_dict:
+        state['amp'] = amp_state_dict
     state.update(kwargs)
     torch.save(state, output_path)
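Storing amp.state_dict() under the 'amp' key keeps the loss-scaler state with the checkpoint, and load_checkpoint restores it when present. Apex's documented checkpointing recipe expects amp.load_state_dict to be called after amp.initialize has run on the rebuilt model and optimizer; here is a self-contained sketch with a toy model and an illustrative file path, not the project's code:

# Checkpoint round-trip with apex AMP state (toy model, illustrative path).
import torch
from apex import amp

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Save: keep the AMP (loss-scaler) state next to the usual fields.
torch.save({'model': model.state_dict(), 'amp': amp.state_dict()},
           '/tmp/checkpoint.pth.tar')

# Load: restore AMP state only after amp.initialize has run again, and
# tolerate older checkpoints that carry no 'amp' entry.
state = torch.load('/tmp/checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(state['model'])
if 'amp' in state:
    amp.load_state_dict(state['amp'])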


@@ -13,13 +13,21 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
     return use_cuda, num_gpus


-def check_update(model, grad_clip, ignore_stopnet=False):
+def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
     if ignore_stopnet:
-        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     else:
-        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     # compatibility with different torch versions
     if isinstance(grad_norm, float):
         if np.isinf(grad_norm):
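Usage-wise, the extended check_update is driven exactly as in the train.py hunks above: pass apex's master parameters when AMP is active, otherwise None so the original model.parameters() or stopnet-filtered clipping applies. Note that when amp_opt_params is given, both branches clip that flat list, so the stopnet-exclusion filter only takes effect in the non-AMP path. A condensed call sketch, with names taken from the diff rather than a standalone program:

# Condensed call sketch for the extended helper (names from the diff above).
amp_opt_params = amp.master_params(optimizer) if amp else None

# Main update: without AMP, clip everything except the stopnet.
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True,
                            amp_opt_params=amp_opt_params)

# Separate stopnet update: fixed max-norm of 1.0.
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0,
                               amp_opt_params=amp_opt_params)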