diff --git a/train.py b/train.py
index 51c73a9f..02f28c1d 100644
--- a/train.py
+++ b/train.py
@@ -222,15 +222,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load
-            if global_step % 10 == 0:
+            if global_step % c.tb_plot_step == 0:
                 iter_stats = {
-                    "loss_posnet": loss_dict['postnet_loss'].item(),
-                    "loss_decoder": loss_dict['decoder_loss'].item(),
                     "lr": current_lr,
                     "grad_norm": grad_norm,
                     "grad_norm_st": grad_norm_st,
                     "step_time": step_time
                 }
+                iter_stats.update(loss_dict)
                 tb_logger.tb_train_iter_stats(global_step, iter_stats)

             if global_step % c.save_step == 0:
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index e3daf574..c50f8060 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -328,6 +328,7 @@ def check_config(c):

     # tensorboard
     _check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
+    _check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
     _check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
     _check_argument('checkpoint', c, restricted=True, val_type=bool)
     _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
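
Note on the two changes above: the hard-coded "% 10" plotting interval becomes configurable through a new config key, tb_plot_step (it is registered with restricted=True, so existing config.json files need to define it before check_config passes), and the per-iteration stats now pick up every entry of loss_dict via iter_stats.update(loss_dict) instead of only the hand-picked postnet and decoder losses. A minimal standalone sketch of that merge, assuming loss_dict already holds plain Python floats at this point in the loop (the numeric values below are illustrative only, not project defaults):

    # Illustrative values; in train.py these come from the current training step.
    loss_dict = {"postnet_loss": 0.41, "decoder_loss": 0.37}
    iter_stats = {"lr": 1e-4, "grad_norm": 2.3, "grad_norm_st": 0.8, "step_time": 0.52}
    iter_stats.update(loss_dict)  # merge every loss entry into the logged stats
    print(sorted(iter_stats))     # loss keys now sit alongside lr / grad_norm / step_time

With this merge, any loss term the criterion adds to loss_dict later is plotted automatically, without touching the TensorBoard logging code again.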