diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index 63a0af44..98a0a939 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -3,9 +3,10 @@ from typing import Dict
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
 
+from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 from TTS.tts.utils.visual import plot_spectrogram
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.processor import AudioProcessor
 
 
 def interpolate_vocoder_input(scale_factor, spec):
@@ -29,13 +30,14 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor = None, audio_config: "Coqpit" = None, name_prefix: str = None) -> Dict:
     """Plot the predicted and the real waveform and their spectrograms.
 
     Args:
         y_hat (torch.tensor): Predicted waveform.
         y (torch.tensor): Real waveform.
-        ap (AudioProcessor): Audio processor used to process the waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None.
+        audio_config (Coqpit): Audio configuration. Only used when ```ap``` is None. Defaults to None.
         name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
 
     Returns:
@@ -48,8 +50,17 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
     y_hat = y_hat[0].squeeze().detach().cpu().numpy()
     y = y[0].squeeze().detach().cpu().numpy()
 
-    spec_fake = ap.melspectrogram(y_hat).T
-    spec_real = ap.melspectrogram(y).T
+    if ap is not None:
+        spec_fake = ap.melspectrogram(y_hat).T
+        spec_real = ap.melspectrogram(y).T
+    elif audio_config is not None:
+        mel_basis = build_mel_basis(**audio_config)
+        spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
+        spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
+        spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
+        spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
+    else:
+        raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")
     spec_diff = np.abs(spec_fake - spec_real)
 
     # plot figure and save it
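
For reference, a minimal usage sketch of the two call paths `plot_results` supports after this change. The audio field names (sample_rate, fft_size, hop_length, win_length, num_mels, mel_fmin, mel_fmax) and the "eval/" prefix are illustrative assumptions modeled on Coqui's audio config, not part of this diff; a plain dict stands in for a Coqpit config here, since the new branch only needs something that unpacks via `**`.

# Usage sketch for the updated plot_results(); assumes the numpy_transforms
# helpers accept these config keys through **kwargs, as the diff implies.
import torch

from TTS.utils.audio.processor import AudioProcessor
from TTS.vocoder.utils.generic_utils import plot_results

y = torch.randn(1, 1, 22050)      # real waveform batch (B x 1 x T)
y_hat = torch.randn(1, 1, 22050)  # predicted waveform batch

# Assumed audio settings, modeled on Coqui's audio config fields.
audio_config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000,
}

# Path 1 (unchanged): compute the mels through an AudioProcessor.
figures = plot_results(y_hat, y, ap=AudioProcessor(**audio_config), name_prefix="eval/")

# Path 2 (new): processor-free, driven only by the dict-like config.
figures = plot_results(y_hat, y, audio_config=audio_config, name_prefix="eval/")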