mirror of https://github.com/coqui-ai/TTS.git
Make plot results more general
This commit is contained in:
parent e7c5db0d97
commit cc57c20162
@@ -3,9 +3,10 @@ from typing import Dict
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 
 from TTS.tts.utils.visual import plot_spectrogram
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.processor import AudioProcessor
 
 
 def interpolate_vocoder_input(scale_factor, spec):
@@ -29,13 +30,14 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor = None, audio_config: "Coqpit" = None, name_prefix: str = None) -> Dict:
     """Plot the predicted and the real waveform and their spectrograms.
 
     Args:
         y_hat (torch.tensor): Predicted waveform.
         y (torch.tensor): Real waveform.
-        ap (AudioProcessor): Audio processor used to process the waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None.
+        audio_config (Coqpit): Audio configuration. Only used when ```ap``` is None. Defaults to None.
         name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
 
     Returns:
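With this signature, plot_results no longer requires a full AudioProcessor; a bare audio config is enough. Below is a minimal sketch of both call styles. It is not part of the commit: the module path (TTS.vocoder.utils.generic_utils), the config field names, and the "eval/" prefix are assumptions based on Coqui's usual conventions.

import torch

from TTS.utils.audio.processor import AudioProcessor
from TTS.vocoder.utils.generic_utils import plot_results  # assumed module path

# Synthetic [batch, 1, samples] waveforms standing in for vocoder output and target.
y_hat = torch.rand(1, 1, 22050)
y = torch.rand(1, 1, 22050)

# Old style: build a full AudioProcessor and hand it over.
ap = AudioProcessor(sample_rate=22050, num_mels=80, fft_size=1024, hop_length=256, win_length=1024)
figures = plot_results(y_hat, y, ap=ap, name_prefix="eval/")

# New style: pass only the audio settings. Every key is unpacked into
# build_mel_basis()/wav_to_mel(), so the names must match those helpers' parameters.
audio_config = {
    "sample_rate": 22050,
    "num_mels": 80,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "mel_fmin": 0,
    "mel_fmax": None,
}
figures = plot_results(y_hat, y, audio_config=audio_config, name_prefix="eval/")

A plain dict works wherever the annotation says Coqpit, since the function only ever **-unpacks it.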
@@ -48,8 +50,17 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
     y_hat = y_hat[0].squeeze().detach().cpu().numpy()
     y = y[0].squeeze().detach().cpu().numpy()
 
-    spec_fake = ap.melspectrogram(y_hat).T
-    spec_real = ap.melspectrogram(y).T
+    if ap is not None:
+        spec_fake = ap.melspectrogram(y_hat).T
+        spec_real = ap.melspectrogram(y).T
+    elif audio_config is not None:
+        mel_basis = build_mel_basis(**audio_config)
+        spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T
+        spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T
+        spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0)
+        spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0)
+    else:
+        raise ValueError(" [!] Either `ap` or `audio_config` must be provided.")
     spec_diff = np.abs(spec_fake - spec_real)
 
     # plot figure and save it
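The elif branch recreates what AudioProcessor.melspectrogram produces, but with the stateless numpy_transforms helpers: build the mel filterbank once, project each waveform's STFT magnitudes onto it, then convert amplitudes to decibels. A rough standalone sketch of that pipeline follows; the config keys are assumptions, and only the call shapes (build_mel_basis(**audio_config), wav_to_mel(y=..., mel_basis=..., **audio_config), amp_to_db(x=..., gain=1.0, base=10.0)) come from the commit itself.

import numpy as np

from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel

# Hypothetical audio settings; every key is unpacked into the helpers below,
# so unused keys must be tolerated by their signatures (the commit relies on this too).
audio_config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": None,
}

# One second of noise standing in for a real waveform.
wav = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)

mel_basis = build_mel_basis(**audio_config)  # mel filterbank, built once and reused
mel = wav_to_mel(y=wav, mel_basis=mel_basis, **audio_config).T  # linear-amplitude mel, frames first
mel_db = amp_to_db(x=mel, gain=1.0, base=10.0)  # same dB scaling plot_results applies

Building mel_basis once and threading it into both wav_to_mel calls avoids recomputing the filterbank per waveform, which is why the commit passes it as an argument instead of leaving each call to derive it from the config.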