From f7add3c8e5decc0e32a7080cb14b45756bd44fc2 Mon Sep 17 00:00:00 2001 From: Eren Date: Sat, 11 Aug 2018 16:53:09 +0200 Subject: [PATCH] tensorboardx plotting figures --- train.py | 16 ++++++++-------- utils/visual.py | 12 ++---------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/train.py b/train.py index b9d1616e..8f908062 100644 --- a/train.py +++ b/train.py @@ -148,12 +148,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, const_spec = plot_spectrogram(const_spec, ap) gt_spec = plot_spectrogram(gt_spec, ap) - tb.add_image('Visual/Reconstruction', const_spec, current_step) - tb.add_image('Visual/GroundTruth', gt_spec, current_step) + tb.add_figure('Visual/Reconstruction', const_spec, current_step) + tb.add_figure('Visual/GroundTruth', gt_spec, current_step) align_img = alignments[0].data.cpu().numpy() align_img = plot_alignment(align_img) - tb.add_image('Visual/Alignment', align_img, current_step) + tb.add_figure('Visual/Alignment', align_img, current_step) # Sample audio audio_signal = linear_output[0].data.cpu().numpy() @@ -275,9 +275,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): gt_spec = plot_spectrogram(gt_spec, ap) align_img = plot_alignment(align_img) - tb.add_image('ValVisual/Reconstruction', const_spec, current_step) - tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) - tb.add_image('ValVisual/ValidationAlignment', align_img, + tb.add_figure('ValVisual/Reconstruction', const_spec, current_step) + tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step) + tb.add_figure('ValVisual/ValidationAlignment', align_img, current_step) # Sample audio @@ -324,9 +324,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): align_img = alignments[0].data.cpu().numpy() linear_spec = plot_spectrogram(linear_spec, ap) align_img = plot_alignment(align_img) - tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, + tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec, current_step) - tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, + tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img, current_step) return avg_linear_loss diff --git a/utils/visual.py b/utils/visual.py index 114e0ef2..8545ffe5 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -15,11 +15,7 @@ def plot_alignment(alignment, info=None): plt.xlabel(xlabel) plt.ylabel('Encoder timestep') plt.tight_layout() - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1]) - plt.close() - return data + return fig def plot_spectrogram(linear_output, audio): @@ -28,8 +24,4 @@ def plot_spectrogram(linear_output, audio): plt.imshow(spectrogram.T, aspect="auto", origin="lower") plt.colorbar() plt.tight_layout() - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1]) - plt.close() - return data + return fig