diff --git a/train.py b/train.py index 45991015..89e21155 100644 --- a/train.py +++ b/train.py @@ -283,8 +283,11 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, "prediction": plot_spectrogram(const_spec, ap), "ground_truth": plot_spectrogram(gt_spec, ap), "alignment": plot_alignment(align_img), - "alignment_backward": plot_alignment(alignments_backward[0].data.cpu().numpy()) } + + if c.bidirectional_decoder: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy()) + tb_logger.tb_train_figures(global_step, figures) # Sample audio