This commit is contained in:
erogol 2020-06-10 13:52:31 +02:00
parent 1a061c4af5
commit c24b57452d
2 changed files with 3 additions and 3 deletions

View File

@ -222,15 +222,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % 10 == 0:
if global_step % c.tb_plot_step == 0:
iter_stats = {
"loss_posnet": loss_dict['postnet_loss'].item(),
"loss_decoder": loss_dict['decoder_loss'].item(),
"lr": current_lr,
"grad_norm": grad_norm,
"grad_norm_st": grad_norm_st,
"step_time": step_time
}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
if global_step % c.save_step == 0:

View File

@ -327,6 +327,7 @@ def check_config(c):
# tensorboard
_check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('checkpoint', c, restricted=True, val_type=bool)
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)