refactor(audio.processor): remove duplicate stft+griffin_lim

This commit is contained in:
Enno Hermann 2023-11-13 21:47:36 +01:00
parent 8fa4de1c8c
commit d75879802a
2 changed files with 28 additions and 48 deletions

View File

@ -8,7 +8,7 @@ import scipy.signal
import soundfile as sf
from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import compute_f0
from TTS.utils.audio.numpy_transforms import compute_f0, stft, griffin_lim
# pylint: disable=too-many-public-methods
@ -460,9 +460,14 @@ class AudioProcessor(object):
np.ndarray: Spectrogram.
"""
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
y = self.apply_preemphasis(y)
D = stft(
y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D))
else:
@ -472,9 +477,14 @@ class AudioProcessor(object):
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
y = self.apply_preemphasis(y)
D = stft(
y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
else:
@ -486,18 +496,16 @@ class AudioProcessor(object):
S = self.denormalize(spectrogram)
S = self._db_to_amp(S)
# Reconstruct phase
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
D = self.denormalize(mel_spectrogram)
S = self._db_to_amp(D)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -515,45 +523,16 @@ class AudioProcessor(object):
mel = self.normalize(S)
return mel
### STFT and ISTFT ###
def _stft(self, y: np.ndarray) -> np.ndarray:
"""Librosa STFT wrapper.
Args:
y (np.ndarray): Audio signal.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=self.fft_size,
def _griffin_lim(self, S):
return griffin_lim(
spec=S,
num_iter=self.griffin_lim_iters,
hop_length=self.hop_length,
win_length=self.win_length,
fft_size=self.fft_size,
pad_mode=self.stft_pad_mode,
window="hann",
center=True,
)
def _istft(self, y: np.ndarray) -> np.ndarray:
"""Librosa iSTFT wrapper."""
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
try:
S_complex = np.abs(S).astype(np.complex)
except AttributeError: # np.complex is deprecated since numpy 1.20.0
S_complex = np.abs(S).astype(complex)
y = self._istft(S_complex * angles)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)
return y
def compute_stft_paddings(self, x, pad_sides=1):
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""

View File

@ -5,6 +5,7 @@ import torch
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import stft
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
TESTS_PATH = get_tests_path()
@ -21,7 +22,7 @@ def test_torch_stft():
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
# librosa stft
wav = ap.load_wav(WAV_FILE)
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
# torch stft
wav = torch.from_numpy(wav[None, :]).float()
M_torch = torch_stft(wav)