diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 4afa5d8f..0c01a79e 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -117,7 +117,7 @@ def format_data(data):
 
 
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch):
+          ap, global_step, epoch, amp):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0))
     model.train()
@@ -172,7 +172,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # backward pass
         loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+        if amp:
+            amp_opt_params = amp.master_params(optimizer)
+        else:
+            amp_opt_params = None
+        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
         optimizer.step()
 
         # compute alignment error (the lower the better )
@@ -183,7 +187,11 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         if c.separate_stopnet:
             loss_dict['stopnet_loss'].backward()
             optimizer_st, _ = adam_weight_decay(optimizer_st)
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            if amp:
+                amp_opt_params = amp.master_params(optimizer)
+            else:
+                amp_opt_params = None
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
             optimizer_st.step()
         else:
             grad_norm_st = 0
@@ -245,7 +253,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                     # save model
                     save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                     optimizer_st=optimizer_st,
-                                    model_loss=loss_dict['postnet_loss'])
+                                    model_loss=loss_dict['postnet_loss'],
+                                    amp_state_dict=amp.state_dict() if amp else None)
 
                 # Diagnostic visualizations
                 const_spec = postnet_output[0].data.cpu().numpy()
@@ -497,6 +506,14 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None
 
+    if c.apex_amp_level:
+        # pylint: disable=import-outside-toplevel
+        from apex import amp
+        model.cuda()
+        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
+    else:
+        amp = None
+
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
 
@@ -515,6 +532,10 @@ def main(args):  # pylint: disable=redefined-outer-name
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model.load_state_dict(model_dict)
             del model_dict
+
+        if amp and 'amp' in checkpoint:
+            amp.load_state_dict(checkpoint['amp'])
+
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],
@@ -564,7 +585,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
 
 
 if __name__ == '__main__':
@@ -616,6 +637,9 @@ if __name__ == '__main__':
     check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))
 
+    if c.apex_amp_level:
+        print(" > apex AMP level: ", c.apex_amp_level)
+
     OUT_PATH = args.continue_path
     if args.continue_path == '':
         OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json
index 23868a33..d98802f7 100644
--- a/TTS/tts/configs/config.json
+++ b/TTS/tts/configs/config.json
@@ -67,6 +67,7 @@
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
     "loss_masking": true,         // enable / disable loss masking against the sequence padding.
     "ga_alpha": 10.0,        // weight for guided attention loss. If > 0, guided attention is enabled.
+    "apex_amp_level": "",  // level of optimization for NVIDIA's apex automatic mixed FP16/FP32 precision (AMP). NOTE: currently only 'O1' is supported; use "" (empty string) to deactivate.
 
     // VALIDATION
     "run_eval": true,
diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py
index db95b0c5..f01c427c 100644
--- a/TTS/tts/utils/io.py
+++ b/TTS/tts/utils/io.py
@@ -6,6 +6,8 @@ import datetime
-def load_checkpoint(model, checkpoint_path, use_cuda=False):
+def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     model.load_state_dict(state['model'])
+    if amp and 'amp' in state:
+        amp.load_state_dict(state['amp'])
     if use_cuda:
         model.cuda()
     # set model stepsize
@@ -14,7 +16,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
     return model, state
 
 
-def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
     new_state_dict = model.state_dict()
     state = {
         'model': new_state_dict,
@@ -24,6 +26,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         'date': datetime.date.today().strftime("%B %d, %Y"),
         'r': r
     }
+    if amp_state_dict:
+        state['amp'] = amp_state_dict
     state.update(kwargs)
     torch.save(state, output_path)
 
diff --git a/TTS/utils/training.py b/TTS/utils/training.py
index 9046f9e0..8166562c 100644
--- a/TTS/utils/training.py
+++ b/TTS/utils/training.py
@@ -13,13 +13,21 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
     return use_cuda, num_gpus
 
 
-def check_update(model, grad_clip, ignore_stopnet=False):
+def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
     if ignore_stopnet:
-        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     else:
-        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        if not amp_opt_params:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+        else:
+            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
+
     # compatibility with different torch versions
     if isinstance(grad_norm, float):
         if np.isinf(grad_norm):
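
For context on how the pieces added above fit together, below is a minimal, self-contained sketch of the usual apex O1 workflow: amp.initialize (as in the new block in main()), gradient clipping over amp.master_params (as in check_update), and amp.state_dict() for checkpointing (as in save_model). The amp.scale_loss wrapper shown here is the standard apex pattern for the backward pass and is not part of this diff; the toy model, tensors, and hyperparameters are placeholders, and running it requires a CUDA GPU with apex installed.

# Minimal sketch of the apex AMP workflow this patch integrates (illustrative only;
# the model, optimizer, and shapes below are placeholders, not TTS code).
import torch
from apex import amp  # NVIDIA apex must be installed separately

model = torch.nn.Linear(80, 80).cuda()                     # stand-in for the Tacotron model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # stand-in for the training optimizer
apex_amp_level = "O1"                                      # mirrors the new "apex_amp_level" config field

# same call as the new block in main(): patches model and optimizer for mixed precision
model, optimizer = amp.initialize(model, optimizer, opt_level=apex_amp_level)

mel = torch.randn(8, 80).cuda()
loss = (model(mel) - mel).pow(2).mean()

optimizer.zero_grad()
# standard apex loss scaling for the backward pass (generic pattern, not shown in this diff)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# clip the FP32 master weights, which is what check_update() does when amp_opt_params is given
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
optimizer.step()

# checkpointing: amp.state_dict() carries the loss-scaler state, as in save_model()/load_checkpoint()
checkpoint = {"model": model.state_dict(), "amp": amp.state_dict()}
torch.save(checkpoint, "/tmp/amp_checkpoint.pth")

Clipping amp.master_params(optimizer) instead of model.parameters() is the clipping target apex documents for AMP training, which is why check_update() gains the amp_opt_params argument.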