visualization updates wrt mean-var scaling

This commit is contained in:
erogol 2020-03-17 13:28:15 +01:00
parent d7cf34ca34
commit 3bbeb43f57
1 changed files with 4 additions and 4 deletions

View File

@ -32,22 +32,22 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
linear_output_ = linear_output.detach().cpu().numpy().squeeze() linear_output_ = linear_output.detach().cpu().numpy().squeeze()
else: else:
linear_output_ = linear_output linear_output_ = linear_output
spectrogram = audio._denormalize(linear_output_) # pylint: disable=protected-access spectrogram = audio._denormalize(linear_output_.T) # pylint: disable=protected-access
fig = plt.figure(figsize=fig_size) fig = plt.figure(figsize=fig_size)
plt.imshow(spectrogram.T, aspect="auto", origin="lower") plt.imshow(spectrogram, aspect="auto", origin="lower")
plt.colorbar() plt.colorbar()
plt.tight_layout() plt.tight_layout()
return fig return fig
def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None): def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None, figsize=[8, 24]):
if spectrogram is not None: if spectrogram is not None:
num_plot = 4 num_plot = 4
else: else:
num_plot = 3 num_plot = 3
label_fontsize = 16 label_fontsize = 16
fig = plt.figure(figsize=(8, 24)) fig = plt.figure(figsize=figsize)
plt.subplot(num_plot, 1, 1) plt.subplot(num_plot, 1, 1)
plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)