diff --git a/utils/visual.py b/utils/visual.py index 87fbc8e4..3d95c2e3 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -27,14 +27,15 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None): return fig -def plot_spectrogram(linear_output, audio, fig_size=(16, 10)): - if isinstance(linear_output, torch.Tensor): - linear_output_ = linear_output.detach().cpu().numpy().squeeze() +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)): + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze() else: - linear_output_ = linear_output - spectrogram = audio._denormalize(linear_output_.T) # pylint: disable=protected-access + spectrogram_ = spectrogram + if ap is not None: + spectrogram_ = ap._denormalize(spectrogram_.T) # pylint: disable=protected-access fig = plt.figure(figsize=fig_size) - plt.imshow(spectrogram, aspect="auto", origin="lower") + plt.imshow(spectrogram_, aspect="auto", origin="lower") plt.colorbar() plt.tight_layout() return fig