coqui-tts/TTS/vocoder/utils/generic_utils.py

from typing import Dict

import numpy as np
import torch
from matplotlib import pyplot as plt

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
from TTS.utils.audio.processor import AudioProcessor


def interpolate_vocoder_input(scale_factor, spec):
    """Interpolate spectrogram by the scale factor.
    It is mainly used to match the sampling rates of
    the tts and vocoder models.

    Args:
        scale_factor (float): scale factor to interpolate the spectrogram
        spec (np.array): spectrogram to be interpolated

    Returns:
        torch.tensor: interpolated spectrogram.
    """
    print(" > before interpolation :", spec.shape)
    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # pylint: disable=not-callable
    spec = torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
    ).squeeze(0)
    print(" > after interpolation :", spec.shape)
    return spec
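
# Example usage (an illustrative sketch, not part of the original module): matching a
# TTS model's output rate to the vocoder's expected rate. The 22.05 kHz / 24 kHz pair
# and the variable names below are assumptions for illustration only.
#
#     scale_factor = [1.0, vocoder_sample_rate / tts_sample_rate]  # e.g. [1.0, 24000 / 22050]
#     if scale_factor[1] != 1.0:
#         vocoder_input = interpolate_vocoder_input(scale_factor, mel_postnet_spec)
#
# The mel axis is left untouched while the time axis is stretched by the ratio, and the
# result comes back as a torch tensor of shape [1, n_mels, T_new].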


def plot_results(
    y_hat: torch.tensor,
    y: torch.tensor,
    ap: AudioProcessor = None,
    audio_config: "Coqpit" = None,
    name_prefix: str = None,
) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Args:
        y_hat (torch.tensor): Predicted waveform.
        y (torch.tensor): Real waveform.
        ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None.
        audio_config (Coqpit): Audio configuration. Only used when `ap` is None. Defaults to None.
        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.

    Returns:
        Dict: output figures keyed by the name of the figures.
    """
    if name_prefix is None:
        name_prefix = ""

    # select an instance from the batch
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    # compute mel spectrograms either with the AudioProcessor or from the raw audio config
    if ap is not None:
        spec_fake = ap.melspectrogram(y_hat).T
        spec_real = ap.melspectrogram(y).T
    elif audio_config is not None:
        mel_basis = build_mel_basis(**audio_config)
        spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
        spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
        spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
        spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
    else:
        raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")

    spec_diff = np.abs(spec_fake - spec_real)

    # plot the waveform comparison figure
    fig_wave = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(y)
    plt.title("groundtruth speech")
    plt.subplot(2, 1, 2)
    plt.plot(y_hat)
    plt.title("generated speech")
    plt.tight_layout()
    plt.close()

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
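
# Example usage (an illustrative sketch, not part of the original module): comparing a
# batch of vocoder outputs against ground-truth audio during evaluation. The config,
# tensor shapes, and name prefix below are assumptions for illustration only.
#
#     ap = AudioProcessor(**config.audio.to_dict())   # processor built from a vocoder config
#     y_hat = model(x)                                 # predicted waveforms, e.g. [B, 1, T]
#     figures = plot_results(y_hat, y, ap=ap, name_prefix="eval/")
#
# figures["eval/speech_comparison"] is a matplotlib Figure comparing the two waveforms,
# while the "spectrogram/*" entries compare mel spectrograms of real vs. generated audio.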