mirror of https://github.com/coqui-ai/TTS.git
tensorboardx plotting figures
This commit is contained in:
parent
b65f5bd1d0
commit
f7add3c8e5
16
train.py
16
train.py
|
@ -148,12 +148,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
const_spec = plot_spectrogram(const_spec, ap)
|
||||
gt_spec = plot_spectrogram(gt_spec, ap)
|
||||
tb.add_image('Visual/Reconstruction', const_spec, current_step)
|
||||
tb.add_image('Visual/GroundTruth', gt_spec, current_step)
|
||||
tb.add_figure('Visual/Reconstruction', const_spec, current_step)
|
||||
tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
|
||||
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
align_img = plot_alignment(align_img)
|
||||
tb.add_image('Visual/Alignment', align_img, current_step)
|
||||
tb.add_figure('Visual/Alignment', align_img, current_step)
|
||||
|
||||
# Sample audio
|
||||
audio_signal = linear_output[0].data.cpu().numpy()
|
||||
|
@ -275,9 +275,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
|||
gt_spec = plot_spectrogram(gt_spec, ap)
|
||||
align_img = plot_alignment(align_img)
|
||||
|
||||
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
|
||||
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
|
||||
tb.add_image('ValVisual/ValidationAlignment', align_img,
|
||||
tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
|
||||
tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
|
||||
tb.add_figure('ValVisual/ValidationAlignment', align_img,
|
||||
current_step)
|
||||
|
||||
# Sample audio
|
||||
|
@ -324,9 +324,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
|||
align_img = alignments[0].data.cpu().numpy()
|
||||
linear_spec = plot_spectrogram(linear_spec, ap)
|
||||
align_img = plot_alignment(align_img)
|
||||
tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
|
||||
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
|
||||
current_step)
|
||||
tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
|
||||
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
|
||||
current_step)
|
||||
return avg_linear_loss
|
||||
|
||||
|
|
|
@ -15,11 +15,7 @@ def plot_alignment(alignment, info=None):
|
|||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
|
||||
plt.close()
|
||||
return data
|
||||
return fig
|
||||
|
||||
|
||||
def plot_spectrogram(linear_output, audio):
|
||||
|
@ -28,8 +24,4 @@ def plot_spectrogram(linear_output, audio):
|
|||
plt.imshow(spectrogram.T, aspect="auto", origin="lower")
|
||||
plt.colorbar()
|
||||
plt.tight_layout()
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
|
||||
plt.close()
|
||||
return data
|
||||
return fig
|
||||
|
|
Loading…
Reference in New Issue