Make plot results more general

This commit is contained in:
Eren Gölge 2022-04-19 09:20:31 +02:00 committed by Eren Gölge
parent e7c5db0d97
commit cc57c20162
1 changed files with 16 additions and 5 deletions

View File

@@ -3,9 +3,10 @@ from typing import Dict
import numpy as np
import torch
from matplotlib import pyplot as plt
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.processor import AudioProcessor
def interpolate_vocoder_input(scale_factor, spec):
@@ -29,13 +30,14 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor=None, audio_config: "Coqpit"= None, name_prefix: str = None) -> Dict:
"""Plot the predicted and the real waveform and their spectrograms.
Args:
y_hat (torch.tensor): Predicted waveform.
y (torch.tensor): Real waveform.
ap (AudioProcessor): Audio processor used to process the waveform.
ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None.
audio_config (Coqpit): Audio configuration. Only used when ```ap``` is None. Defaults to None.
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
Returns:
@@ -48,8 +50,17 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
y_hat = y_hat[0].squeeze().detach().cpu().numpy()
y = y[0].squeeze().detach().cpu().numpy()
spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T
if ap is not None:
spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T
elif audio_config is not None:
mel_basis = build_mel_basis(**audio_config)
spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
else:
raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")
spec_diff = np.abs(spec_fake - spec_real)
# plot figure and save it