diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index 9c373fa7..8bbfc55e 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -221,7 +221,7 @@ def train(model, criterion, optimizer, if args.rank == 0: tb_logger.tb_train_epoch_stats(global_step, epoch_stats) # TODO: plot model stats - if c.tb_model_param_stats: + if c.tb_model_param_stats and args.rank == 0: tb_logger.tb_model_weights(model, global_step) return keep_avg.avg_values, global_step