mirror of https://github.com/coqui-ai/TTS.git
use pytorch amp for mixed precision training for Tacotron
commit a7aefd5c50
parent 67e2b664e5
@@ -28,7 +28,8 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                   init_distributed, reduce_tensor)
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
-                                     remove_experiment_folder, set_init_dict)
+                                     remove_experiment_folder, set_init_dict,
+                                     set_amp_context)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
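The newly imported set_amp_context helper is not defined in this diff; presumably it returns torch.cuda.amp.autocast() when mixed precision is enabled and a no-op context otherwise. A minimal sketch under that assumption:

    # sketch of the assumed helper (not part of this diff)
    from contextlib import nullcontext
    import torch

    def set_amp_context(mixed_precision):
        # autocast runs the wrapped forward pass in float16 where safe;
        # the null context keeps full-precision behaviour unchanged.
        return torch.cuda.amp.autocast() if mixed_precision else nullcontext()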
@@ -127,7 +128,7 @@ def format_data(data, speaker_mapping=None):


 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, amp, speaker_mapping=None):
+          ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
@@ -152,10 +153,12 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # setup lr
         if c.noam_schedule:
             scheduler.step()

         optimizer.zero_grad()
         if optimizer_st:
             optimizer_st.zero_grad()

-        # forward pass model
-        if c.bidirectional_decoder or c.double_decoder_consistency:
-            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+        with set_amp_context(c.mixed_precision):
+            # forward pass model
+            if c.bidirectional_decoder or c.double_decoder_consistency:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
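For reference, autocast has its own enable switch, so the helper is equivalent to wrapping the forward pass directly. A sketch, not code from this commit; model, inputs, targets and criterion are placeholders:

    # equivalent to set_amp_context, using autocast's enabled flag
    with torch.cuda.amp.autocast(enabled=c.mixed_precision):
        outputs = model(inputs)              # forward pass runs in float16 where safe
        loss = criterion(outputs, targets)   # loss is computed inside the same context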
@@ -179,38 +182,50 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                                       alignments, alignment_lengths, alignments_backward,
                                       text_lengths)

-        # backward pass
-        if amp is not None:
-            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss_dict['loss'].backward()
+        # check nan loss
+        if torch.isnan(loss_dict['loss']).any():
+            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

-        optimizer, current_lr = adam_weight_decay(optimizer)
-        if amp:
-            amp_opt_params = amp.master_params(optimizer)
-        else:
-            amp_opt_params = None
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
-        optimizer.step()
+        # optimizer step
+        if c.mixed_precision:
+            # model optimizer step in mixed precision mode
+            scaler.scale(loss_dict['loss']).backward()
+            scaler.unscale_(optimizer)
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            scaler.step(optimizer)
+            scaler.update()
+
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                scaler_st.scale(loss_dict['stopnet_loss']).backward()
+                scaler_st.unscale_(optimizer_st)
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                scaler_st.step(optimizer_st)
+                scaler_st.update()
+            else:
+                grad_norm_st = 0
+        else:
+            # main model optimizer step
+            loss_dict['loss'].backward()
+            optimizer, current_lr = adam_weight_decay(optimizer)
+            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
+            optimizer.step()
+
+            # stopnet optimizer step
+            if c.separate_stopnet:
+                loss_dict['stopnet_loss'].backward()
+                optimizer_st, _ = adam_weight_decay(optimizer_st)
+                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+                optimizer_st.step()
+            else:
+                grad_norm_st = 0

         # compute alignment error (the lower the better )
         align_error = 1 - alignment_diagonal_score(alignments)
         loss_dict['align_error'] = align_error

-        # backpass and check the grad norm for stop loss
-        if c.separate_stopnet:
-            loss_dict['stopnet_loss'].backward()
-            optimizer_st, _ = adam_weight_decay(optimizer_st)
-            if amp:
-                amp_opt_params = amp.master_params(optimizer)
-            else:
-                amp_opt_params = None
-            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0, amp_opt_params=amp_opt_params)
-            optimizer_st.step()
-        else:
-            grad_norm_st = 0
-
         step_time = time.time() - start_time
         epoch_time += step_time

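The mixed-precision branch above follows the standard torch.cuda.amp recipe: scale the loss before backward, unscale before gradient clipping, then let the scaler decide whether to step. A minimal self-contained sketch of the same pattern with generic names (model, criterion, optimizer and loader are placeholders, not objects from this repo):

    import torch

    scaler = torch.cuda.amp.GradScaler()

    for inputs, targets in loader:                        # placeholder data loader
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = criterion(model(inputs), targets)      # fp16 forward where safe
        scaler.scale(loss).backward()                     # backward on the scaled loss
        scaler.unscale_(optimizer)                        # so clipping sees true gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)                            # skips the step on inf/nan gradients
        scaler.update()                                   # adapts the loss scale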
@@ -269,7 +284,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
                                 model_loss=loss_dict['postnet_loss'],
-                                amp_state_dict=amp.state_dict() if amp else None)
+                                scaler=scaler.state_dict() if c.mixed_precision else None)

             # Diagnostic visualizations
             const_spec = postnet_output[0].data.cpu().numpy()
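The save_checkpoint helper is not shown in this diff; judging by the keyword it receives, it presumably just places the scaler state dict into the checkpoint payload next to the model and optimizer. An assumed sketch (field and variable names illustrative only):

    # assumed shape of the checkpoint payload; helper internals are not in this diff
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict() if scaler is not None else None,
        'step': global_step,
        'epoch': epoch,
    }
    torch.save(state, checkpoint_path)    # checkpoint_path is a placeholder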
@@ -505,6 +520,10 @@ def main(args): # pylint: disable=redefined-outer-name

     model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

+    # scalers for mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+    scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
+
     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
@@ -514,14 +533,6 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None

-    if c.apex_amp_level == "O1":
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

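Unlike apex, native AMP needs no amp.initialize() call: the model and optimizer stay untouched, and the GradScaler objects created earlier carry all the mixed-precision state. A hedged sketch of the difference; the enabled-flag form is an alternative to the None-when-disabled pattern the diff uses:

    # apex-style setup (removed above): model and optimizer are rewritten in place
    #   model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)

    # native AMP setup: nothing to wrap, just keep a scaler around
    model.cuda()
    scaler = torch.cuda.amp.GradScaler(enabled=c.mixed_precision)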
@@ -543,9 +554,6 @@ def main(args): # pylint: disable=redefined-outer-name
             model.load_state_dict(model_dict)
             del model_dict

-            if amp and 'amp' in checkpoint:
-                amp.load_state_dict(checkpoint['amp'])
-
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(" > Model restored from step %d" % checkpoint['step'],
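The apex state restore is dropped here, and the diff does not show a matching restore of the GradScaler state that checkpoints now carry under the scaler key. If that is wanted when resuming, it would look roughly like this (a hypothetical addition, not part of this commit):

    # hypothetical restore of the saved scaler state when resuming training
    if c.mixed_precision and checkpoint.get('scaler') is not None:
        scaler.load_state_dict(checkpoint['scaler'])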
@@ -588,14 +596,14 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, amp, speaker_mapping)
+                                                 global_step, epoch, scaler, scaler_st, speaker_mapping)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
+                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':
@@ -647,8 +655,8 @@ if __name__ == '__main__':
     check_config_tts(c)
     _ = os.path.dirname(os.path.realpath(__file__))

-    if c.apex_amp_level == 'O1':
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision mode is ON")

     OUT_PATH = args.continue_path
     if args.continue_path == '':
@@ -80,7 +80,7 @@ def format_test_data(data):


 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch):
+          scheduler, scaler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -102,7 +102,6 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

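From here on the same change is applied to what looks like a vocoder training script (judging by setup_generator and the noise schedule): the GradScaler is no longer built on every call to train() but created once in main() and passed in. That matters because the scaler adapts its loss scale as training runs and its state is now checkpointed; rebuilding it each call would throw that adaptation away. A minimal sketch of the intended flow, with placeholder names, not a verbatim copy of the script:

    scaler = torch.cuda.amp.GradScaler() if use_mixed_precision else None   # built once

    for epoch in range(num_epochs):
        _, global_step = train(model, criterion, optimizer,
                               scheduler, scaler, ap, global_step, epoch)
        # the same scaler object keeps its loss-scale state across epochs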
@@ -208,7 +207,8 @@ def train(model, criterion, optimizer,
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=loss_dict)
+                            model_losses=loss_dict,
+                            scaler=scaler.state_dict() if c.mixed_precision else None)

         end_time = time.time()

@@ -336,6 +336,9 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup models
     model = setup_generator(c)

+    # scaler for mixed_precision
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

@@ -396,7 +399,7 @@ def main(args): # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
-                               scheduler, ap, global_step,
+                               scheduler, scaler, ap, global_step,
                                epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
@@ -413,7 +416,8 @@ def main(args): # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict)
+                            model_losses=eval_avg_loss_dict,
+                            scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':