mirror of https://github.com/coqui-ai/TTS.git
apex AMP O1 optimizations
parent 299f1cafe6
commit eb905aafd3

@@ -117,7 +117,7 @@ def format_data(data):
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch):
+          ap, global_step, epoch, amp):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0))
     model.train()

@@ -172,7 +172,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # backward pass
         loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+        if amp:
+            amp_opt_params = amp.master_params(optimizer)
+        else:
+            amp_opt_params = None
+        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
         optimizer.step()

         # compute alignment error (the lower the better )

@@ -183,7 +187,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         if c.separate_stopnet:
             loss_dict['stopnet_loss'].backward()
             optimizer_st, _ = adam_weight_decay(optimizer_st)
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            if amp:
+                amp_opt_params = amp.master_params(optimizer)
+            else:
+                amp_opt_params = None
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
             optimizer_st.step()
         else:
             grad_norm_st = 0

@@ -245,7 +253,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 # save model
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
-                                model_loss=loss_dict['postnet_loss'])
+                                model_loss=loss_dict['postnet_loss'],
+                                amp_state_dict=amp.state_dict() if amp else None)

             # Diagnostic visualizations
             const_spec = postnet_output[0].data.cpu().numpy()

@@ -497,6 +506,14 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None

+    if c.apex_amp_level:
+        # pylint: disable=import-outside-toplevel
+        from apex import amp
+        model.cuda()
+        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
+    else:
+        amp = None
+
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
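
For orientation, here is a minimal sketch of how the pieces introduced in this commit fit together in an apex AMP O1 training step. It is not code from this repository: the model, optimizer, batch and clip value are placeholders, and it assumes NVIDIA apex is installed and a CUDA device is available. The amp.scale_loss wrapper is standard apex usage for the backward pass; clipping amp.master_params(optimizer) mirrors what check_update() does in the last hunk below when amp_opt_params is passed.

# Hedged sketch of an apex AMP "O1" training step (not repository code).
# Assumes apex is installed with CUDA support and a GPU is present.
import torch
from apex import amp

apex_amp_level = "O1"                       # stands in for c.apex_amp_level
model = torch.nn.Linear(80, 80).cuda()      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Patch model/optimizer so whitelisted ops run in FP16 with FP32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level=apex_amp_level)

x = torch.randn(4, 80).cuda()               # placeholder batch
loss = model(x).pow(2).mean()               # placeholder loss

optimizer.zero_grad()
# Scale the loss so FP16 gradients do not underflow, then backpropagate.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# Clip the gradients the optimizer will actually step, i.e. the master params,
# which is what check_update(..., amp_opt_params=amp.master_params(optimizer)) does.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
optimizer.step()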

@@ -515,6 +532,10 @@ def main(args):  # pylint: disable=redefined-outer-name
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model.load_state_dict(model_dict)
             del model_dict
+
+        if amp and 'amp' in checkpoint:
+            amp.load_state_dict(checkpoint['amp'])
+
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],

@@ -564,7 +585,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)


 if __name__ == '__main__':

@@ -616,6 +637,9 @@ if __name__ == '__main__':
     check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))

+    if c.apex_amp_level:
+        print(" > apex AMP level: ", c.apex_amp_level)
+
     OUT_PATH = args.continue_path
     if args.continue_path == '':
         OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)

@@ -67,6 +67,7 @@
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
     "loss_masking": true, // enable / disable loss masking against the sequence padding.
     "ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.
+    "apex_amp_level": "", // level of optimization with NVIDIA's apex feature for automatic mixed FP16/FP32 precision (AMP), NOTE: currently only O1 is supported, use "" (empty string) to deactivate

     // VALIDATION
     "run_eval": true,
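
A note on the default: the trainer only checks the option's truthiness, so the empty string disables every AMP code path. A trivial sketch of that gating (the variable below stands in for c.apex_amp_level; not repository code):

# "" -> AMP disabled, "O1" -> AMP enabled, matching the guards in the hunks above.
apex_amp_level = ""                 # value read from config.json
if apex_amp_level:                  # "" is falsy, "O1" is truthy
    print(" > apex AMP level: ", apex_amp_level)
else:
    amp = None                      # all AMP branches are skipped downstream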

@@ -6,6 +6,8 @@ import datetime
 def load_checkpoint(model, checkpoint_path, use_cuda=False):
     state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     model.load_state_dict(state['model'])
+    if amp and 'amp' in state:
+        amp.load_state_dict(state['amp'])
     if use_cuda:
         model.cuda()
     # set model stepsize

@@ -14,7 +16,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
     return model, state


-def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
     new_state_dict = model.state_dict()
     state = {
         'model': new_state_dict,

@@ -24,6 +26,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         'date': datetime.date.today().strftime("%B %d, %Y"),
         'r': r
     }
+    if amp_state_dict:
+        state['amp'] = amp_state_dict
     state.update(kwargs)
     torch.save(state, output_path)

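
A hedged sketch of the checkpoint round trip that the two hunks above enable. The names and the path are placeholders, not repository code; it assumes apex is installed with a GPU available, and follows apex's recommendation to call amp.load_state_dict() only after amp.initialize() has run.

# Hedged sketch of saving and restoring the apex AMP state alongside a checkpoint.
import torch
from apex import amp

model = torch.nn.Linear(80, 80).cuda()                  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

checkpoint_path = "/tmp/tts_amp_checkpoint.pth"          # placeholder path

# Saving: mirrors save_model(..., amp_state_dict=amp.state_dict() if amp else None).
state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict(),          # dynamic loss-scale bookkeeping
}
torch.save(state, checkpoint_path)

# Restoring: mirrors the `if amp and 'amp' in ...` guards in the hunks above.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if 'amp' in checkpoint:
    amp.load_state_dict(checkpoint['amp'])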

@@ -13,13 +13,21 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
     return use_cuda, num_gpus


-def check_update(model, grad_clip, ignore_stopnet=False):
+def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
     if ignore_stopnet:
-        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     else:
-        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)

     # compatibility with different torch versions
     if isinstance(grad_norm, float):
         if np.isinf(grad_norm):
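
To make the extended check_update() signature concrete, a short usage sketch. This is not repository code: grad_clip is a placeholder for c.grad_clip, model, optimizer and amp are assumed to be set up as in the initialization sketch further up, and the second return value is named skip_flag here only because the function body above defines it.

# Hedged usage sketch for check_update() with and without apex AMP.
grad_clip = 5.0                                   # placeholder for c.grad_clip

if amp is not None:
    # With AMP, clip the gradients of the optimizer's (master) parameters.
    amp_opt_params = amp.master_params(optimizer)
else:
    amp_opt_params = None

grad_norm, skip_flag = check_update(model, grad_clip,
                                    ignore_stopnet=True,
                                    amp_opt_params=amp_opt_params)
if skip_flag:
    print(" > gradient check failed, this update could be skipped")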