diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index de6d95c5..4fd1f19c 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -87,6 +87,39 @@ def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False) return fig +def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the input characters. + + Args: + pitch (np.array): Pitch values. + chars (str): Characters to place to the x-axis. + + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = [c for c in chars] + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output,