from typing import Dict

import numpy as np
import torch
from matplotlib import pyplot as plt

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
from TTS.utils.audio.processor import AudioProcessor


def interpolate_vocoder_input(scale_factor, spec):
    """Interpolate a spectrogram by the given scale factor.

    It is mainly used to match the sampling rates of the TTS and vocoder models.

    Args:
        scale_factor (float or List[float]): Scale factor(s) for the (frequency, time)
            axes, e.g. ``[1, vocoder_sr / tts_sr]`` to stretch only the time axis.
        spec (np.ndarray): Spectrogram to be interpolated, shaped ``[num_mels, T]``.

    Returns:
        torch.Tensor: Interpolated spectrogram, shaped ``[1, num_mels, T']``.
    """
    print(" > before interpolation :", spec.shape)
    # Add batch and channel dims: [num_mels, T] -> [1, 1, num_mels, T],
    # the 4D layout expected by bilinear interpolation.
    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # pylint: disable=not-callable
    spec = torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
    ).squeeze(0)  # drop the batch dim: [1, 1, num_mels, T'] -> [1, num_mels, T']
    print(" > after interpolation :", spec.shape)
    return spec


def plot_results(
    y_hat: torch.Tensor,
    y: torch.Tensor,
    ap: AudioProcessor = None,
    audio_config: "Coqpit" = None,
    name_prefix: str = None,
) -> Dict:
    """Plot the predicted and the real waveforms and their spectrograms.

    Args:
        y_hat (torch.Tensor): Predicted waveform.
        y (torch.Tensor): Real waveform.
        ap (AudioProcessor): Audio processor used to compute the spectrograms. Defaults to None.
        audio_config (Coqpit): Audio configuration. Only used when ``ap`` is None. Defaults to None.
        name_prefix (str, optional): Prefix used to name the figures. Defaults to None.

    Returns:
        Dict: Output figures keyed by figure name.
    """
    if name_prefix is None:
        name_prefix = ""

    # select an instance from the batch
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    if ap is not None:
        spec_fake = ap.melspectrogram(y_hat).T
        spec_real = ap.melspectrogram(y).T
    elif audio_config is not None:
        mel_basis = build_mel_basis(**audio_config)
        spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
        spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
        # convert the mel amplitudes to the dB scale
        spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
        spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
    else:
        raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")

    spec_diff = np.abs(spec_fake - spec_real)

    # plot the real and the generated waveforms, one above the other
    fig_wave = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(y)
    plt.title("groundtruth speech")
    plt.subplot(2, 1, 2)
    plt.plot(y_hat)
    plt.title("generated speech")
    plt.tight_layout()
    plt.close()

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
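

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test of
# `interpolate_vocoder_input`. The sample rates below are illustrative
# assumptions, a TTS model at 22050 Hz driving a vocoder trained at
# 24000 Hz; only the time axis is stretched, the mel-band axis is kept.
# ---------------------------------------------------------------------------
def _demo_interpolation():
    tts_sr = 22050  # assumed TTS model sampling rate
    vocoder_sr = 24000  # assumed vocoder sampling rate
    # Stand-in for a real [num_mels, T] spectrogram.
    dummy_spec = np.random.rand(80, 100).astype(np.float32)
    # Keep the 80 mel bands (factor 1.0), stretch time by the rate ratio.
    scale_factor = [1.0, vocoder_sr / tts_sr]
    interpolated = interpolate_vocoder_input(scale_factor, dummy_spec)
    # -> torch.Size([1, 80, 108]); 108 == floor(100 * 24000 / 22050)
    print(" > demo output shape:", tuple(interpolated.shape))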
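

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module) for `plot_results`. The
# `AudioProcessor` settings below are illustrative assumptions; in practice
# the processor is built from the model config (e.g. `AudioProcessor(**config.audio)`
# in Coqui TTS) so the spectrograms match the training features.
# ---------------------------------------------------------------------------
def _demo_plot_results():
    ap = AudioProcessor(
        sample_rate=22050,  # assumed values; use your model's audio config
        num_mels=80,
        fft_size=1024,
        hop_length=256,
        win_length=1024,
        mel_fmin=0,
        mel_fmax=8000,
    )
    # One-sample batch of 1 s of noise standing in for real/generated audio.
    y = torch.rand(1, 1, 22050)
    y_hat = torch.rand(1, 1, 22050)
    figures = plot_results(y_hat, y, ap=ap, name_prefix="demo/")
    print(" > figure keys:", sorted(figures.keys()))


if __name__ == "__main__":
    _demo_interpolation()
    _demo_plot_results()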