use pytorch amp for mixed precision training for Tacotron

This commit is contained in:
erogol 2020-11-12 12:51:56 +01:00
parent 67e2b664e5
commit a7aefd5c50
2 changed files with 79 additions and 67 deletions

View File

@ -28,7 +28,8 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor) init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch, create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict,
set_amp_context)
from TTS.utils.io import copy_config_file, load_config from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
@ -127,7 +128,7 @@ def format_data(data, speaker_mapping=None):
def train(model, criterion, optimizer, optimizer_st, scheduler, def train(model, criterion, optimizer, optimizer_st, scheduler,
ap, global_step, epoch, amp, speaker_mapping=None): ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
data_loader = setup_loader(ap, model.decoder.r, is_val=False, data_loader = setup_loader(ap, model.decoder.r, is_val=False,
verbose=(epoch == 0), speaker_mapping=speaker_mapping) verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train() model.train()
@ -152,10 +153,12 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
# setup lr # setup lr
if c.noam_schedule: if c.noam_schedule:
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
if optimizer_st: if optimizer_st:
optimizer_st.zero_grad() optimizer_st.zero_grad()
with set_amp_context(c.mixed_precision):
# forward pass model # forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency: if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
@ -179,38 +182,50 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
alignments, alignment_lengths, alignments_backward, alignments, alignment_lengths, alignments_backward,
text_lengths) text_lengths)
# backward pass # check nan loss
if amp is not None: if torch.isnan(loss_dict['loss']).any():
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss: raise RuntimeError(f'Detected NaN loss at step {global_step}.')
scaled_loss.backward()
else:
loss_dict['loss'].backward()
# optimizer step
if c.mixed_precision:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict['loss']).backward()
scaler.unscale_(optimizer)
optimizer, current_lr = adam_weight_decay(optimizer) optimizer, current_lr = adam_weight_decay(optimizer)
if amp: grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
amp_opt_params = amp.master_params(optimizer) scaler.step(optimizer)
scaler.update()
# stopnet optimizer step
if c.separate_stopnet:
scaler_st.scale( loss_dict['stopnet_loss']).backward()
scaler.unscale_(optimizer_st)
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
scaler_st.step(optimizer)
scaler_st.update()
else: else:
amp_opt_params = None grad_norm_st = 0
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params) else:
# main model optimizer step
loss_dict['loss'].backward()
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
optimizer.step() optimizer.step()
# stopnet optimizer step
if c.separate_stopnet:
loss_dict['stopnet_loss'].backward()
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step()
else:
grad_norm_st = 0
# compute alignment error (the lower the better ) # compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments) align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error loss_dict['align_error'] = align_error
# backpass and check the grad norm for stop loss
if c.separate_stopnet:
loss_dict['stopnet_loss'].backward()
optimizer_st, _ = adam_weight_decay(optimizer_st)
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
optimizer_st.step()
else:
grad_norm_st = 0
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
@ -269,7 +284,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st, optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'], model_loss=loss_dict['postnet_loss'],
amp_state_dict=amp.state_dict() if amp else None) scaler=scaler.state_dict() if c.mixed_precision else None)
# Diagnostic visualizations # Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy() const_spec = postnet_output[0].data.cpu().numpy()
@ -505,6 +520,10 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim) model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
# scalers for mixed precision training
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
params = set_weight_decay(model, c.wd) params = set_weight_decay(model, c.wd)
optimizer = RAdam(params, lr=c.lr, weight_decay=0) optimizer = RAdam(params, lr=c.lr, weight_decay=0)
if c.stopnet and c.separate_stopnet: if c.stopnet and c.separate_stopnet:
@ -514,14 +533,6 @@ def main(args): # pylint: disable=redefined-outer-name
else: else:
optimizer_st = None optimizer_st = None
if c.apex_amp_level == "O1":
# pylint: disable=import-outside-toplevel
from apex import amp
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
# setup criterion # setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4) criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
@ -543,9 +554,6 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict) model.load_state_dict(model_dict)
del model_dict del model_dict
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
for group in optimizer.param_groups: for group in optimizer.param_groups:
group['lr'] = c.lr group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'], print(" > Model restored from step %d" % checkpoint['step'],
@ -588,14 +596,14 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r) print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer, train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap, optimizer_st, scheduler, ap,
global_step, epoch, amp, speaker_mapping) global_step, epoch, scaler, scaler_st, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss'] target_loss = train_avg_loss_dict['avg_postnet_loss']
if c.run_eval: if c.run_eval:
target_loss = eval_avg_loss_dict['avg_postnet_loss'] target_loss = eval_avg_loss_dict['avg_postnet_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None) OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__': if __name__ == '__main__':
@ -647,8 +655,8 @@ if __name__ == '__main__':
check_config_tts(c) check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__)) _ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level == 'O1': if c.mixed_precision:
print(" > apex AMP level: ", c.apex_amp_level) print(" > Mixed precision mode is ON")
OUT_PATH = args.continue_path OUT_PATH = args.continue_path
if args.continue_path == '': if args.continue_path == '':

View File

@ -80,7 +80,7 @@ def format_test_data(data):
def train(model, criterion, optimizer, def train(model, criterion, optimizer,
scheduler, ap, global_step, epoch): scheduler, scaler, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train() model.train()
epoch_time = 0 epoch_time = 0
@ -102,7 +102,6 @@ def train(model, criterion, optimizer,
model.compute_noise_level(noise_schedule['num_steps'], model.compute_noise_level(noise_schedule['num_steps'],
noise_schedule['min_val'], noise_schedule['min_val'],
noise_schedule['max_val']) noise_schedule['max_val'])
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
for num_iter, data in enumerate(data_loader): for num_iter, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
@ -208,7 +207,8 @@ def train(model, criterion, optimizer,
global_step, global_step,
epoch, epoch,
OUT_PATH, OUT_PATH,
model_losses=loss_dict) model_losses=loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None)
end_time = time.time() end_time = time.time()
@ -336,6 +336,9 @@ def main(args): # pylint: disable=redefined-outer-name
# setup models # setup models
model = setup_generator(c) model = setup_generator(c)
# scaler for mixed_precision
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
# setup optimizers # setup optimizers
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0) optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
@ -396,7 +399,7 @@ def main(args): # pylint: disable=redefined-outer-name
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs) c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer, _, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step, scheduler, scaler, ap, global_step,
epoch) epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap, eval_avg_loss_dict = evaluate(model, criterion, ap,
global_step, epoch) global_step, epoch)
@ -413,7 +416,8 @@ def main(args): # pylint: disable=redefined-outer-name
global_step, global_step,
epoch, epoch,
OUT_PATH, OUT_PATH,
model_losses=eval_avg_loss_dict) model_losses=eval_avg_loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__': if __name__ == '__main__':