save figures in visualize of set

This commit is contained in:
Eren Golge 2019-05-12 17:35:44 +02:00
parent 6331bccefc
commit 5e679f746d
1 changed files with 6 additions and 2 deletions

View File

@ -30,14 +30,14 @@ def plot_spectrogram(linear_output, audio):
return fig
def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None):
def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None):
if spectrogram is not None:
num_plot = 4
else:
num_plot = 3
label_fontsize = 16
plt.figure(figsize=(8, 24))
fig = plt.figure(figsize=(8, 24))
plt.subplot(num_plot, 1, 1)
plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
@ -69,3 +69,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()
plt.colorbar()
if output_path:
print(output_path)
fig.savefig(output_path)