mirror of https://github.com/coqui-ai/TTS.git
bug fix
This commit is contained in:
parent
1a061c4af5
commit
c24b57452d
5
train.py
5
train.py
|
@ -222,15 +222,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
if args.rank == 0:
|
||||
# Plot Training Iter Stats
|
||||
# reduce TB load
|
||||
if global_step % 10 == 0:
|
||||
if global_step % c.tb_plot_step == 0:
|
||||
iter_stats = {
|
||||
"loss_posnet": loss_dict['postnet_loss'].item(),
|
||||
"loss_decoder": loss_dict['decoder_loss'].item(),
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"grad_norm_st": grad_norm_st,
|
||||
"step_time": step_time
|
||||
}
|
||||
iter_stats.update(loss_dict)
|
||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||
|
||||
if global_step % c.save_step == 0:
|
||||
|
|
|
@ -327,6 +327,7 @@ def check_config(c):
|
|||
|
||||
# tensorboard
|
||||
_check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
|
|
Loading…
Reference in New Issue