diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 033a5191..6eb3abdf 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -1,6 +1,8 @@ -import torch import librosa import matplotlib +import numpy as np +import torch + matplotlib.use('Agg') import matplotlib.pyplot as plt from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme @@ -43,6 +45,8 @@ def plot_spectrogram(spectrogram, spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T else: spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype( + np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ if ap is not None: spectrogram_ = ap._denormalize(spectrogram_) # pylint: disable=protected-access fig = plt.figure(figsize=fig_size)