diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 6eb3abdf..17cba648 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -17,6 +17,8 @@ def plot_alignment(alignment, alignment_ = alignment.detach().cpu().numpy().squeeze() else: alignment_ = alignment + alignment_ = alignment_.astype( + np.float32) if alignment_.dtype == np.float16 else alignment_ fig, ax = plt.subplots(figsize=fig_size) im = ax.imshow(alignment_.T, aspect='auto',