tensorboardx plotting figures

This commit is contained in:
Eren 2018-08-11 16:53:09 +02:00
parent b65f5bd1d0
commit f7add3c8e5
2 changed files with 10 additions and 18 deletions

View File

@ -148,12 +148,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, ap)
tb.add_image('Visual/Reconstruction', const_spec, current_step)
tb.add_image('Visual/GroundTruth', gt_spec, current_step)
tb.add_figure('Visual/Reconstruction', const_spec, current_step)
tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Visual/Alignment', align_img, current_step)
tb.add_figure('Visual/Alignment', align_img, current_step)
# Sample audio
audio_signal = linear_output[0].data.cpu().numpy()
@ -275,9 +275,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
gt_spec = plot_spectrogram(gt_spec, ap)
align_img = plot_alignment(align_img)
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_image('ValVisual/ValidationAlignment', align_img,
tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_figure('ValVisual/ValidationAlignment', align_img,
current_step)
# Sample audio
@ -324,9 +324,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
align_img = alignments[0].data.cpu().numpy()
linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(align_img)
tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
current_step)
tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
current_step)
return avg_linear_loss

View File

@ -15,11 +15,7 @@ def plot_alignment(alignment, info=None):
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
plt.close()
return data
return fig
def plot_spectrogram(linear_output, audio):
@ -28,8 +24,4 @@ def plot_spectrogram(linear_output, audio):
plt.imshow(spectrogram.T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
plt.close()
return data
return fig