diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index dd9f0e55..26c08150 100644
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -28,7 +28,8 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                   init_distributed, reduce_tensor)
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
-                                     remove_experiment_folder, set_init_dict)
+                                     remove_experiment_folder, set_init_dict,
+                                     set_amp_context)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
@@ -127,7 +128,7 @@ def format_data(data, speaker_mapping=None):
 
 
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, amp, speaker_mapping=None):
+          ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
@@ -152,65 +153,79 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # setup lr
         if c.noam_schedule:
             scheduler.step()
+
         optimizer.zero_grad()
         if optimizer_st:
             optimizer_st.zero_grad()
 
-        # forward pass model
-        if c.bidirectional_decoder or c.double_decoder_consistency:
-            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
-        else:
-            decoder_output, postnet_output, alignments, stop_tokens = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
-            decoder_backward_output = None
-            alignments_backward = None
+        with set_amp_context(c.mixed_precision):
+            # forward pass model
+            if c.bidirectional_decoder or c.double_decoder_consistency:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+                decoder_backward_output = None
+                alignments_backward = None
 
-        # set the [alignment] lengths wrt reduction factor for guided attention
-        if mel_lengths.max() % model.decoder.r != 0:
-            alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
-        else:
-            alignment_lengths = mel_lengths // model.decoder.r
+            # set the [alignment] lengths wrt reduction factor for guided attention
+            if mel_lengths.max() % model.decoder.r != 0:
+                alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
+            else:
+                alignment_lengths = mel_lengths // model.decoder.r
 
-        # compute loss
-        loss_dict = criterion(postnet_output, decoder_output, mel_input,
-                              linear_input, stop_tokens, stop_targets,
-                              mel_lengths, decoder_backward_output,
-                              alignments, alignment_lengths, alignments_backward,
-                              text_lengths)
+            # compute loss
+            loss_dict = criterion(postnet_output, decoder_output, mel_input,
+                                  linear_input, stop_tokens, stop_targets,
+                                  mel_lengths, decoder_backward_output,
+                                  alignments, alignment_lengths, alignments_backward,
+                                  text_lengths)
 
-        # backward pass
-        if amp is not None:
-            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
-                scaled_loss.backward()
+            # check nan loss
+            if torch.isnan(loss_dict['loss']).any():
+                raise RuntimeError(f'Detected NaN loss at step {global_step}.')
+
+        # optimizer step
+        if c.mixed_precision:
+            # model optimizer step in mixed precision mode
+            scaler.scale(loss_dict['loss']).backward()
+            scaler.unscale_(optimizer)
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            scaler.step(optimizer)
+            scaler.update()
+
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                scaler_st.scale(loss_dict['stopnet_loss']).backward()
+                scaler_st.unscale_(optimizer_st)
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                scaler_st.step(optimizer_st)
+                scaler_st.update()
+            else:
+                grad_norm_st = 0
         else:
+            # main model optimizer step
             loss_dict['loss'].backward()
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            optimizer.step()
-        optimizer, current_lr = adam_weight_decay(optimizer)
-        if amp:
-            amp_opt_params = amp.master_params(optimizer)
-        else:
-            amp_opt_params = None
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
-        optimizer.step()
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                loss_dict['stopnet_loss'].backward()
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                optimizer_st.step()
+            else:
+                grad_norm_st = 0
 
         # compute alignment error (the lower the better )
         align_error = 1 - alignment_diagonal_score(alignments)
         loss_dict['align_error'] = align_error
 
-        # backpass and check the grad norm for stop loss
-        if c.separate_stopnet:
-            loss_dict['stopnet_loss'].backward()
-            optimizer_st, _ = adam_weight_decay(optimizer_st)
-            if amp:
-                amp_opt_params = amp.master_params(optimizer)
-            else:
-                amp_opt_params = None
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
-            optimizer_st.step()
-        else:
-            grad_norm_st = 0
-
         step_time = time.time() - start_time
         epoch_time += step_time
@@ -269,7 +284,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r,
                                 OUT_PATH, optimizer_st=optimizer_st,
                                 model_loss=loss_dict['postnet_loss'],
-                                amp_state_dict=amp.state_dict() if amp else None)
+                                scaler=scaler.state_dict() if c.mixed_precision else None)
 
                 # Diagnostic visualizations
                 const_spec = postnet_output[0].data.cpu().numpy()
@@ -505,6 +520,10 @@ def main(args):  # pylint: disable=redefined-outer-name
 
     model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
 
+    # scalers for mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+    scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
+
     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
@@ -514,14 +533,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None
 
-    if c.apex_amp_level == "O1":
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
 
@@ -543,9 +554,6 @@ def main(args):  # pylint: disable=redefined-outer-name
             model.load_state_dict(model_dict)
             del model_dict
 
-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],
@@ -588,14 +596,14 @@ def main(args):  # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                   optimizer_st, scheduler, ap,
-                                                  global_step, epoch, amp, speaker_mapping)
+                                                  global_step, epoch, scaler, scaler_st, speaker_mapping)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
+                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
 
 
 if __name__ == '__main__':
@@ -647,8 +655,8 @@ if __name__ == '__main__':
     check_config_tts(c)
     _ = os.path.dirname(os.path.realpath(__file__))
 
-    if c.apex_amp_level == 'O1':
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision mode is ON")
 
     OUT_PATH = args.continue_path
     if args.continue_path == '':
diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py
index 8bbfc55e..e5caa9e2 100644
--- a/TTS/bin/train_vocoder_wavegrad.py
+++ b/TTS/bin/train_vocoder_wavegrad.py
@@ -80,7 +80,7 @@ def format_test_data(data):
 
 
 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch):
+          scheduler, scaler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -102,7 +102,6 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
 
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -208,7 +207,8 @@ def train(model, criterion, optimizer,
                                 global_step,
                                 epoch,
                                 OUT_PATH,
-                                model_losses=loss_dict)
+                                model_losses=loss_dict,
+                                scaler=scaler.state_dict() if c.mixed_precision else None)
 
         end_time = time.time()
 
@@ -336,6 +336,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup models
     model = setup_generator(c)
 
+    # scaler for mixed_precision
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
 
@@ -396,7 +399,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
-                               scheduler, ap, global_step,
+                               scheduler, scaler, ap, global_step,
                                epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
@@ -413,7 +416,8 @@ def main(args):  # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict)
+                            model_losses=eval_avg_loss_dict,
+                            scaler=scaler.state_dict() if c.mixed_precision else None)
 
 
 if __name__ == '__main__':
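
For context on the change itself: the hunks above drop apex (`amp.scale_loss`, `amp.master_params`, `opt_level`) in favor of PyTorch-native automatic mixed precision, where `set_amp_context(c.mixed_precision)` presumably wraps `torch.cuda.amp.autocast` and a `GradScaler` handles loss scaling. A minimal sketch of that pattern follows; the names `model`, `optimizer`, `batch`, and `compute_loss` are generic placeholders, not this repo's training loop.

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def train_step(model, optimizer, batch, compute_loss, grad_clip=1.0):
        optimizer.zero_grad()
        # forward pass and loss under autocast so eligible ops run in float16
        with torch.cuda.amp.autocast():
            output = model(batch)
            loss = compute_loss(output, batch)
        # scale the loss so small fp16 gradients do not underflow, then backprop
        scaler.scale(loss).backward()
        # unscale before clipping so the threshold applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        # scaler.step() skips the parameter update when any gradient is inf/nan
        scaler.step(optimizer)
        scaler.update()
        return loss.item()

The diff gives the stopnet its own scaler (`scaler_st`) next to the main `scaler`; whichever arrangement is used, each scaler should unscale and step the same optimizer whose loss it scaled, as in the stopnet branch above.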