mirror of https://github.com/coqui-ai/TTS.git
use pytorch amp for mixed precision training for Tacotron

commit a7aefd5c50 (parent 67e2b664e5)
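For context, a minimal sketch of the native PyTorch AMP pattern (torch.cuda.amp.autocast plus GradScaler) that this commit moves to, replacing the apex path removed further down. The model, optimizer, criterion and data are illustrative stand-ins, and `set_amp_context` used in the diff is assumed to be a thin repo helper around autocast:

import torch

# illustrative objects; any CUDA model / optimizer / data would do (requires a CUDA device)
model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()  # scales the loss so fp16 gradients do not underflow

for batch, target in [(torch.randn(8, 80).cuda(), torch.randn(8, 80).cuda())]:
    optimizer.zero_grad()
    # run the forward pass and loss in mixed precision
    with torch.cuda.amp.autocast():
        loss = criterion(model(batch), target)
    # scale the loss, backprop, then step and update through the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()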

@@ -28,7 +28,8 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                      init_distributed, reduce_tensor)
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
-                                     remove_experiment_folder, set_init_dict)
+                                     remove_experiment_folder, set_init_dict,
+                                     set_amp_context)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger

@@ -127,7 +128,7 @@ def format_data(data, speaker_mapping=None):


 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, amp, speaker_mapping=None):
+          ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()

@@ -152,65 +153,79 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # setup lr
         if c.noam_schedule:
             scheduler.step()

         optimizer.zero_grad()
         if optimizer_st:
             optimizer_st.zero_grad()

-        # forward pass model
-        if c.bidirectional_decoder or c.double_decoder_consistency:
-            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
-        else:
-            decoder_output, postnet_output, alignments, stop_tokens = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
-            decoder_backward_output = None
-            alignments_backward = None
+        with set_amp_context(c.mixed_precision):
+            # forward pass model
+            if c.bidirectional_decoder or c.double_decoder_consistency:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model(
+                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
+                decoder_backward_output = None
+                alignments_backward = None

-        # set the [alignment] lengths wrt reduction factor for guided attention
-        if mel_lengths.max() % model.decoder.r != 0:
-            alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
-        else:
-            alignment_lengths = mel_lengths // model.decoder.r
+            # set the [alignment] lengths wrt reduction factor for guided attention
+            if mel_lengths.max() % model.decoder.r != 0:
+                alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
+            else:
+                alignment_lengths = mel_lengths // model.decoder.r

-        # compute loss
-        loss_dict = criterion(postnet_output, decoder_output, mel_input,
-                              linear_input, stop_tokens, stop_targets,
-                              mel_lengths, decoder_backward_output,
-                              alignments, alignment_lengths, alignments_backward,
-                              text_lengths)
+            # compute loss
+            loss_dict = criterion(postnet_output, decoder_output, mel_input,
+                                  linear_input, stop_tokens, stop_targets,
+                                  mel_lengths, decoder_backward_output,
+                                  alignments, alignment_lengths, alignments_backward,
+                                  text_lengths)

-        # backward pass
-        if amp is not None:
-            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
-                scaled_loss.backward()
+        # check nan loss
+        if torch.isnan(loss_dict['loss']).any():
+            raise RuntimeError(f'Detected NaN loss at step {global_step}.')
+
+        # optimizer step
+        if c.mixed_precision:
+            # model optimizer step in mixed precision mode
+            scaler.scale(loss_dict['loss']).backward()
+            scaler.unscale_(optimizer)
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            scaler.step(optimizer)
+            scaler.update()
+
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                scaler_st.scale(loss_dict['stopnet_loss']).backward()
+                scaler.unscale_(optimizer_st)
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                scaler_st.step(optimizer)
+                scaler_st.update()
+            else:
+                grad_norm_st = 0
+        else:
+            # main model optimizer step
+            loss_dict['loss'].backward()
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            optimizer.step()
-        optimizer, current_lr = adam_weight_decay(optimizer)
-        if amp:
-            amp_opt_params = amp.master_params(optimizer)
-        else:
-            amp_opt_params = None
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
-        optimizer.step()
+
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                loss_dict['stopnet_loss'].backward()
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                optimizer_st.step()
+            else:
+                grad_norm_st = 0

         # compute alignment error (the lower the better )
         align_error = 1 - alignment_diagonal_score(alignments)
         loss_dict['align_error'] = align_error

-        # backpass and check the grad norm for stop loss
-        if c.separate_stopnet:
-            loss_dict['stopnet_loss'].backward()
-            optimizer_st, _ = adam_weight_decay(optimizer_st)
-            if amp:
-                amp_opt_params = amp.master_params(optimizer)
-            else:
-                amp_opt_params = None
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
-            optimizer_st.step()
-        else:
-            grad_norm_st = 0
-
         step_time = time.time() - start_time
         epoch_time += step_time
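In the hunk above, gradients are unscaled with `scaler.unscale_(...)` before the repo's `check_update` clips them and `scaler.step(...)` applies them, which is the order native AMP expects. As a point of comparison, a minimal sketch of one way to drive both the main and stopnet optimizers from a single GradScaler, using `torch.nn.utils.clip_grad_norm_` in place of the repo helpers (all names here are illustrative, not the repo's API):

import torch

def amp_step(loss, stopnet_loss, model, optimizer, optimizer_st, scaler, grad_clip):
    # scale and backprop each loss; the two losses may share graph pieces,
    # so keep the graph alive for the second backward
    scaler.scale(loss).backward(retain_graph=True)
    scaler.scale(stopnet_loss).backward()

    # gradients must be unscaled back to fp32 before clipping them
    scaler.unscale_(optimizer)
    scaler.unscale_(optimizer_st)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    # step both optimizers, then update the scale factor once per iteration
    scaler.step(optimizer)
    scaler.step(optimizer_st)
    scaler.update()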

@@ -269,7 +284,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
                                 model_loss=loss_dict['postnet_loss'],
-                                amp_state_dict=amp.state_dict() if amp else None)
+                                scaler=scaler.state_dict() if c.mixed_precision else None)

             # Diagnostic visualizations
             const_spec = postnet_output[0].data.cpu().numpy()
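The checkpoint now carries the scaler state so a resumed run keeps its loss-scale history. A rough sketch of what the `scaler=` keyword is assumed to end up doing inside `save_checkpoint` (the exact dict layout in TTS.utils.io may differ):

import torch

def save_checkpoint_sketch(model, optimizer, scaler, step, path):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # None when training in full precision
        'scaler': scaler.state_dict() if scaler is not None else None,
        'step': step,
    }
    torch.save(state, path)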

@@ -505,6 +520,10 @@ def main(args): # pylint: disable=redefined-outer-name
     model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

+    # scalers for mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+    scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
+
     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
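As an aside, GradScaler also accepts an `enabled` flag, so the scaler objects could be created unconditionally and become no-ops in full precision; a sketch of that alternative to the conditional construction above (the flags stand in for `c.mixed_precision` and `c.separate_stopnet`):

import torch

# with enabled=False the scaler is a no-op, so scale()/step()/update()
# can be called unconditionally in the training loop
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_st = torch.cuda.amp.GradScaler(enabled=False)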

@@ -514,14 +533,6 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None

-    if c.apex_amp_level == "O1":
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)


@@ -543,9 +554,6 @@ def main(args): # pylint: disable=redefined-outer-name
             model.load_state_dict(model_dict)
             del model_dict

-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],
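The apex state restore is dropped here; a resumed mixed-precision run would reload the native scaler state the same way the model and optimizer are reloaded, roughly as sketched below (the checkpoint path and the 'scaler' key name are assumptions based on the save side of this diff):

import torch

checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # illustrative path
scaler = torch.cuda.amp.GradScaler()
if checkpoint.get('scaler') is not None:
    # restore the loss-scale history saved by the training loop
    scaler.load_state_dict(checkpoint['scaler'])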

@@ -588,14 +596,14 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, amp, speaker_mapping)
+                                                 global_step, epoch, scaler, scaler_st, speaker_mapping)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
+                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':

@@ -647,8 +655,8 @@ if __name__ == '__main__':
     check_config_tts(c)
     _ = os.path.dirname(os.path.realpath(__file__))

-    if c.apex_amp_level == 'O1':
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision mode is ON")

     OUT_PATH = args.continue_path
     if args.continue_path == '':


@@ -80,7 +80,7 @@ def format_test_data(data):


 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch):
+          scheduler, scaler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0

@@ -102,7 +102,6 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()


@@ -208,7 +207,8 @@ def train(model, criterion, optimizer,
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=loss_dict)
+                            model_losses=loss_dict,
+                            scaler=scaler.state_dict() if c.mixed_precision else None)

     end_time = time.time()


@@ -336,6 +336,9 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup models
     model = setup_generator(c)

+    # scaler for mixed_precision
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)


@@ -396,7 +399,7 @@ def main(args): # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
-                               scheduler, ap, global_step,
+                               scheduler, scaler, ap, global_step,
                                epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)

@@ -413,7 +416,8 @@ def main(args): # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict)
+                            model_losses=eval_avg_loss_dict,
+                            scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':
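The vocoder hunks above move the scaler into main(), pass it into train() and save its state, but the per-step use inside the vocoder loop falls outside these hunks. Under the same native-AMP pattern it would look roughly like this sketch (function and argument names are illustrative, not the repo's):

import torch

def vocoder_step(model, criterion, optimizer, scaler, batch, target, use_amp):
    optimizer.zero_grad()
    # forward pass and loss in mixed precision when enabled
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = criterion(model(batch), target)
    if use_amp:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.item()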