mirror of https://github.com/coqui-ai/TTS.git
Make plot results more general
parent e7c5db0d97
commit cc57c20162
@@ -3,9 +3,10 @@ from typing import Dict
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 
 from TTS.tts.utils.visual import plot_spectrogram
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.processor import AudioProcessor
 
 
 def interpolate_vocoder_input(scale_factor, spec):
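The added import swaps class-bound spectrogram code for the free functions in TTS.utils.audio.numpy_transforms. A minimal sketch of that functional mel pipeline, mirroring the call pattern used further down in this commit; the plain dict standing in for the Coqpit audio config, and its key names, are assumptions rather than part of this diff:

import numpy as np

from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel

# Assumed audio settings; a real run reads these from a Coqpit audio config,
# and the exact keys expected by build_mel_basis/wav_to_mel may differ by version.
audio_config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": None,
}

wav = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)  # 1 s of noise

mel_basis = build_mel_basis(**audio_config)  # filterbank built once, reused per waveform
mel = wav_to_mel(y=wav, mel_basis=mel_basis, **audio_config)  # linear-amplitude mel, as in the diff
mel_db = amp_to_db(x=mel, gain=1.0, base=10.0)  # same gain/base the commit applies
print(mel_db.shape)  # mel spectrogram; plot_results transposes it with .T before plotting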
@@ -29,13 +30,14 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor = None, audio_config: "Coqpit" = None, name_prefix: str = None) -> Dict:
     """Plot the predicted and the real waveform and their spectrograms.
 
     Args:
         y_hat (torch.tensor): Predicted waveform.
         y (torch.tensor): Real waveform.
-        ap (AudioProcessor): Audio processor used to process the waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None.
+        audio_config (Coqpit): Audio configuration. Only used when `ap` is None. Defaults to None.
         name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
 
     Returns:
@@ -48,8 +50,17 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
     y_hat = y_hat[0].squeeze().detach().cpu().numpy()
     y = y[0].squeeze().detach().cpu().numpy()
 
-    spec_fake = ap.melspectrogram(y_hat).T
-    spec_real = ap.melspectrogram(y).T
+    if ap is not None:
+        spec_fake = ap.melspectrogram(y_hat).T
+        spec_real = ap.melspectrogram(y).T
+    elif audio_config is not None:
+        mel_basis = build_mel_basis(**audio_config)
+        spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
+        spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
+        spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
+        spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
+    else:
+        raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")
     spec_diff = np.abs(spec_fake - spec_real)
 
     # plot figure and save it
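After this change, plot_results accepts either a full AudioProcessor or a bare audio config, and raises a ValueError when given neither. A hedged usage sketch; the module path for plot_results, the AudioProcessor constructor arguments, and the config keys are assumptions, not taken from this diff:

import torch

from TTS.utils.audio.processor import AudioProcessor
from TTS.vocoder.utils.generic_utils import plot_results  # assumed location of plot_results

y_hat = torch.randn(1, 1, 22050)  # dummy predicted waveform: (batch, 1, samples)
y = torch.randn(1, 1, 22050)      # dummy reference waveform

# Path 1: the AudioProcessor owns the audio settings (pre-existing behavior).
ap = AudioProcessor(sample_rate=22050, fft_size=1024, hop_length=256, win_length=1024, num_mels=80)
figures = plot_results(y_hat, y, ap=ap, name_prefix="eval/")

# Path 2 (new): no AudioProcessor; a config dict is unpacked into the numpy transforms.
# Key names here are assumptions about what build_mel_basis/wav_to_mel accept.
audio_config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": None,
}
figures = plot_results(y_hat, y, audio_config=audio_config, name_prefix="eval/")

# plot_results(y_hat, y) with neither argument raises:
#   ValueError:  [!] Either `ap` or `audio_config` must be provided.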