mirror of https://github.com/coqui-ai/TTS.git
Plot pitch over spectrogram
This commit is contained in:
parent
d847a68e42
commit
c1513ec4cd
|
@ -41,9 +41,9 @@ def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
|
|||
x_lengths = T
|
||||
max_idxs = x_lengths - segment_size + 1
|
||||
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
|
||||
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
|
||||
ret = segment(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
|
||||
ret = segment(x, segment_indices, segment_size)
|
||||
return ret, segment_indices
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -49,6 +49,46 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
|
|||
return fig
|
||||
|
||||
|
||||
def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False):
    """Plot a pitch (F0) curve on top of a spectrogram.

    The spectrogram is drawn with ``imshow`` on the primary axis and the pitch
    values are drawn as a red line on a twin y-axis sharing the same x-axis.

    Args:
        pitch (np.array or torch.Tensor): Pitch values.
        spectrogram (np.array or torch.Tensor): Spectrogram values.
        ap: Optional object exposing ``denormalize``. When given, it is applied
            to the spectrogram before plotting.
        fig_size (tuple): Figure size passed to matplotlib. ``None`` falls back
            to the current ``figure.figsize`` default.
        output_fig (bool): If False, the figure is closed before it is returned
            so it does not stay registered with pyplot.

    Returns:
        The matplotlib figure holding both plots.

    Shapes:
        pitch: :math:`(T,)`
        spec: :math:`(C, T)`

    NOTE(review): the input is transposed before ``imshow``; confirm the
    documented ``(C, T)`` layout against the callers.
    """
    if isinstance(spectrogram, torch.Tensor):
        spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
    else:
        spectrogram_ = spectrogram.T
    # imshow cannot handle half-precision input, so promote float16 to float32.
    if spectrogram_.dtype == np.float16:
        spectrogram_ = spectrogram_.astype(np.float32)
    if ap is not None:
        spectrogram_ = ap.denormalize(spectrogram_)  # pylint: disable=protected-access

    # Accept tensors for the pitch as well (e.g. still on GPU / attached to the graph).
    if isinstance(pitch, torch.Tensor):
        pitch = pitch.detach().cpu().numpy()

    # Pass the size directly instead of temporarily overwriting the global
    # rcParams: nothing is left modified if plotting raises. ``figsize=None``
    # makes matplotlib use its current default, matching the old behaviour.
    fig, ax = plt.subplots(figsize=fig_size)

    ax.imshow(spectrogram_, aspect="auto", origin="lower")
    ax.set_xlabel("time")
    ax.set_ylabel("spec_freq")

    # Second y-axis so the F0 scale is independent of the spectrogram bins.
    ax2 = ax.twinx()
    ax2.plot(pitch, linewidth=5.0, color="red")
    ax2.set_ylabel("F0")

    if not output_fig:
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)
    return fig
|
||||
|
||||
|
||||
def visualize(
|
||||
alignment,
|
||||
postnet_output,
|
||||
|
|
Loading…
Reference in New Issue