mirror of https://github.com/coqui-ai/TTS.git

apex AMP O1 optimizations

commit eb905aafd3
parent 299f1cafe6

@@ -117,7 +117,7 @@ def format_data(data):
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch):
+          ap, global_step, epoch, amp):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0))
     model.train()

@@ -172,7 +172,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # backward pass
         loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+        if amp:
+            amp_opt_params = amp.master_params(optimizer)
+        else:
+            amp_opt_params = None
+        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
         optimizer.step()

         # compute alignment error (the lower the better )

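For reference, a minimal sketch (not code from this commit) of how an apex O1 backward pass and gradient clipping usually fit together: the loss is scaled through amp.scale_loss before backpropagation, and clipping runs over amp.master_params(optimizer), the FP32 master copies the optimizer actually steps on. The hunk above only switches the clipping target; the helper name backward_and_clip and the use_amp flag below are illustrative.

import torch
from apex import amp  # NVIDIA apex must be installed; O1 patches torch ops to mixed precision

def backward_and_clip(loss, model, optimizer, grad_clip, use_amp):
    # Sketch only: the hunk above keeps loss.backward() unscaled and only changes
    # the clipping target; the scale_loss wrapper here is the usual apex pattern,
    # not something this diff adds.
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        params = amp.master_params(optimizer)  # FP32 master parameters
    else:
        loss.backward()
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    optimizer.step()
    return grad_norm
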
@@ -183,7 +187,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         if c.separate_stopnet:
             loss_dict['stopnet_loss'].backward()
             optimizer_st, _ = adam_weight_decay(optimizer_st)
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            if amp:
+                amp_opt_params = amp.master_params(optimizer)
+            else:
+                amp_opt_params = None
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
             optimizer_st.step()
         else:
             grad_norm_st = 0

@@ -245,7 +253,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 # save model
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
-                                model_loss=loss_dict['postnet_loss'])
+                                model_loss=loss_dict['postnet_loss'],
+                                amp_state_dict=amp.state_dict() if amp else None)

                 # Diagnostic visualizations
                 const_spec = postnet_output[0].data.cpu().numpy()

@@ -497,6 +506,14 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None

+    if c.apex_amp_level:
+        # pylint: disable=import-outside-toplevel
+        from apex import amp
+        model.cuda()
+        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
+    else:
+        amp = None
+
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

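A hedged sketch of the initialization order this hunk relies on, with a plain nn.Linear standing in for the Tacotron model (apex and a CUDA device are assumed):

import torch
from apex import amp

model = torch.nn.Linear(80, 80)           # stand-in for the Tacotron model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

apex_amp_level = "O1"                     # mirrors c.apex_amp_level; "" disables AMP
if apex_amp_level:
    model.cuda()                          # apex expects the model on GPU before initialize()
    model, optimizer = amp.initialize(model, optimizer, opt_level=apex_amp_level)
else:
    amp = None                            # later code branches on "if amp:"
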
@@ -515,6 +532,10 @@ def main(args): # pylint: disable=redefined-outer-name
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model.load_state_dict(model_dict)
             del model_dict
+
+        if amp and 'amp' in checkpoint:
+            amp.load_state_dict(checkpoint['amp'])
+
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],

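Ordering matters on resume: amp.initialize() from the previous hunk has to run before amp.load_state_dict() here, since apex can only restore its loss-scale state into an initialized AMP context. A hedged end-to-end sketch with a stand-in model and a hypothetical checkpoint file name:

import torch
from apex import amp

model = torch.nn.Linear(80, 80).cuda()    # stand-in for the Tacotron model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

checkpoint = torch.load("checkpoint_10000.pth.tar", map_location="cpu")  # hypothetical path
model.load_state_dict(checkpoint['model'])
if 'amp' in checkpoint:
    amp.load_state_dict(checkpoint['amp'])  # restores dynamic loss-scale bookkeeping
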
@@ -564,7 +585,7 @@ def main(args): # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)


 if __name__ == '__main__':

@@ -616,6 +637,9 @@ if __name__ == '__main__':
     check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))

+    if c.apex_amp_level:
+        print(" > apex AMP level: ", c.apex_amp_level)
+
     OUT_PATH = args.continue_path
     if args.continue_path == '':
         OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)

@@ -67,6 +67,7 @@
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
     "loss_masking": true, // enable / disable loss masking against the sequence padding.
     "ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.
+    "apex_amp_level": "", // level of optimization with NVIDIA's apex feature for automatic mixed FP16/FP32 precision (AMP), NOTE: currently only O1 is supported, use "" (empty string) to deactivate

     // VALIDATION
     "run_eval": true,

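The flag is the only switch for AMP; per the comment only O1 is supported and an empty string disables it. A small hedged sketch of a validation a caller might add (the commit itself does not add this check; the plain dict stands in for the parsed config):

c = {"apex_amp_level": "O1"}    # "" (empty string) deactivates AMP

level = c["apex_amp_level"]
if level and level != "O1":
    # the config comment states only O1 is currently supported
    raise ValueError("apex_amp_level must be 'O1' or '', got %r" % level)
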
@@ -6,6 +6,8 @@ import datetime
 def load_checkpoint(model, checkpoint_path, use_cuda=False):
     state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     model.load_state_dict(state['model'])
+    if amp and 'amp' in state:
+        amp.load_state_dict(state['amp'])
     if use_cuda:
         model.cuda()
     # set model stepsize

@@ -14,7 +16,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
     return model, state


-def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
     new_state_dict = model.state_dict()
     state = {
         'model': new_state_dict,

@@ -24,6 +26,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         'date': datetime.date.today().strftime("%B %d, %Y"),
         'r': r
     }
+    if amp_state_dict:
+        state['amp'] = amp_state_dict
     state.update(kwargs)
     torch.save(state, output_path)

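Together with load_checkpoint above, checkpoints now round-trip the AMP state under an 'amp' key, written only when a state dict is passed in. A minimal hedged sketch of the resulting checkpoint layout (the helper name is illustrative, not the repository's API):

import datetime
import torch

def build_checkpoint_state(model, r, amp_state_dict=None):
    state = {
        'model': model.state_dict(),
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r,
    }
    if amp_state_dict:
        state['amp'] = amp_state_dict   # e.g. amp.state_dict() from the training script
    return state

# torch.save(build_checkpoint_state(model, r=7, amp_state_dict=None), "checkpoint.pth.tar")
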
@@ -13,13 +13,21 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
     return use_cuda, num_gpus


-def check_update(model, grad_clip, ignore_stopnet=False):
+def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
     if ignore_stopnet:
-        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     else:
-        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)

     # compatibility with different torch versions
     if isinstance(grad_norm, float):
         if np.isinf(grad_norm):