use pytorch amp for mixed precision training for Tacotron

This commit is contained in:
erogol 2020-11-12 12:51:56 +01:00
parent 67e2b664e5
commit a7aefd5c50
2 changed files with 79 additions and 67 deletions

View File

@ -28,7 +28,8 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
remove_experiment_folder, set_init_dict,
set_amp_context)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
@ -127,7 +128,7 @@ def format_data(data, speaker_mapping=None):
def train(model, criterion, optimizer, optimizer_st, scheduler,
ap, global_step, epoch, amp, speaker_mapping=None):
ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
data_loader = setup_loader(ap, model.decoder.r, is_val=False,
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train()
@ -152,65 +153,79 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
# setup lr
if c.noam_schedule:
scheduler.step()
optimizer.zero_grad()
if optimizer_st:
optimizer_st.zero_grad()
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
decoder_backward_output = None
alignments_backward = None
with set_amp_context(c.mixed_precision):
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
decoder_backward_output = None
alignments_backward = None
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
else:
alignment_lengths = mel_lengths // model.decoder.r
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
else:
alignment_lengths = mel_lengths // model.decoder.r
# compute loss
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, alignments_backward,
text_lengths)
# compute loss
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, alignments_backward,
text_lengths)
# backward pass
if amp is not None:
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
scaled_loss.backward()
# check nan loss
if torch.isnan(loss_dict['loss']).any():
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
# optimizer step
if c.mixed_precision:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict['loss']).backward()
scaler.unscale_(optimizer)
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
scaler.step(optimizer)
scaler.update()
# stopnet optimizer step
if c.separate_stopnet:
scaler_st.scale( loss_dict['stopnet_loss']).backward()
scaler.unscale_(optimizer_st)
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
scaler_st.step(optimizer)
scaler_st.update()
else:
grad_norm_st = 0
else:
# main model optimizer step
loss_dict['loss'].backward()
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
optimizer.step()
optimizer, current_lr = adam_weight_decay(optimizer)
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
optimizer.step()
# stopnet optimizer step
if c.separate_stopnet:
loss_dict['stopnet_loss'].backward()
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step()
else:
grad_norm_st = 0
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
# backpass and check the grad norm for stop loss
if c.separate_stopnet:
loss_dict['stopnet_loss'].backward()
optimizer_st, _ = adam_weight_decay(optimizer_st)
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
optimizer_st.step()
else:
grad_norm_st = 0
step_time = time.time() - start_time
epoch_time += step_time
@ -269,7 +284,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'],
amp_state_dict=amp.state_dict() if amp else None)
scaler=scaler.state_dict() if c.mixed_precision else None)
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
@ -505,6 +520,10 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
# scalers for mixed precision training
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
params = set_weight_decay(model, c.wd)
optimizer = RAdam(params, lr=c.lr, weight_decay=0)
if c.stopnet and c.separate_stopnet:
@ -514,14 +533,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
optimizer_st = None
if c.apex_amp_level == "O1":
# pylint: disable=import-outside-toplevel
from apex import amp
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
# setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
@ -543,9 +554,6 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict)
del model_dict
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
@ -588,14 +596,14 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap,
global_step, epoch, amp, speaker_mapping)
global_step, epoch, scaler, scaler_st, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss']
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_postnet_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__':
@ -647,8 +655,8 @@ if __name__ == '__main__':
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level == 'O1':
print(" > apex AMP level: ", c.apex_amp_level)
if c.mixed_precision:
print(" > Mixed precision mode is ON")
OUT_PATH = args.continue_path
if args.continue_path == '':

View File

@ -80,7 +80,7 @@ def format_test_data(data):
def train(model, criterion, optimizer,
scheduler, ap, global_step, epoch):
scheduler, scaler, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train()
epoch_time = 0
@ -102,7 +102,6 @@ def train(model, criterion, optimizer,
model.compute_noise_level(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -208,7 +207,8 @@ def train(model, criterion, optimizer,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict)
model_losses=loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None)
end_time = time.time()
@ -336,6 +336,9 @@ def main(args): # pylint: disable=redefined-outer-name
# setup models
model = setup_generator(c)
# scaler for mixed_precision
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
# setup optimizers
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
@ -396,7 +399,7 @@ def main(args): # pylint: disable=redefined-outer-name
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
scheduler, scaler, ap, global_step,
epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap,
global_step, epoch)
@ -413,7 +416,8 @@ def main(args): # pylint: disable=redefined-outer-name
global_step,
epoch,
OUT_PATH,
model_losses=eval_avg_loss_dict)
model_losses=eval_avg_loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__':