def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False):
    """Plot a pitch curve on top of a spectrogram.

    The spectrogram is drawn as an image and the pitch values are overlaid in
    red on a twin y-axis (labelled ``F0``) that shares the same x-axis.

    Args:
        pitch (np.array or torch.Tensor): Pitch values.
        spectrogram (np.array or torch.Tensor): Spectrogram values. Tensors are
            detached, moved to CPU, converted to numpy and squeezed.
        ap (optional): Object exposing a ``denormalize`` method. When given,
            the spectrogram is passed through ``ap.denormalize`` before it is
            drawn.
        fig_size (tuple, optional): Figure size handed to matplotlib. ``None``
            keeps matplotlib's current default figure size.
        output_fig (bool): When ``False`` the figure is closed before it is
            returned; when ``True`` it is left open.

    Returns:
        The matplotlib figure holding both plots.

    Shapes:
        pitch: :math:`(T,)`
        spectrogram: :math:`(T, C)`
    """
    # NOTE(review): an earlier revision of this docstring listed the
    # spectrogram as (C, T). The input is transposed below and drawn with the
    # x-axis labelled "time", underneath a pitch curve of length T, which
    # implies time is the first axis of the input. Confirm against callers.
    if isinstance(spectrogram, torch.Tensor):
        spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
    else:
        spectrogram_ = spectrogram.T

    # Half-precision data is upcast to float32 before any further processing.
    if spectrogram_.dtype == np.float16:
        spectrogram_ = spectrogram_.astype(np.float32)
    if ap is not None:
        spectrogram_ = ap.denormalize(spectrogram_)

    # Give tensor pitch input the same treatment as tensor spectrograms so
    # that values living on an accelerator or attached to the autograd graph
    # can be plotted.
    if isinstance(pitch, torch.Tensor):
        pitch = pitch.detach().cpu().numpy()

    # Passing the size to ``subplots`` avoids temporarily overriding the global
    # ``plt.rcParams["figure.figsize"]``, which would stay modified if any
    # plotting call below raised. ``figsize=None`` falls back to the rcParams
    # default, matching the previous behaviour for ``fig_size=None``.
    fig, ax = plt.subplots(figsize=fig_size)

    ax.imshow(spectrogram_, aspect="auto", origin="lower")
    ax.set_xlabel("time")
    ax.set_ylabel("spec_freq")

    # Second y-axis on the same x-axis for the pitch values.
    ax2 = ax.twinx()
    ax2.plot(pitch, linewidth=5.0, color="red")
    ax2.set_ylabel("F0")

    if not output_fig:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
    return fig