From d75879802a276066a9002ed99851d073abeb6494 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Mon, 13 Nov 2023 21:47:36 +0100
Subject: [PATCH] refactor(audio.processor): remove duplicate stft+griffin_lim

---
 TTS/utils/audio/processor.py               | 73 ++++++++--------
 tests/vocoder_tests/test_vocoder_losses.py |  3 +-
 2 files changed, 28 insertions(+), 48 deletions(-)

diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py
index 4ceb7da4..b2991338 100644
--- a/TTS/utils/audio/processor.py
+++ b/TTS/utils/audio/processor.py
@@ -8,7 +8,7 @@ import scipy.signal
 import soundfile as sf
 
 from TTS.tts.utils.helpers import StandardScaler
-from TTS.utils.audio.numpy_transforms import compute_f0
+from TTS.utils.audio.numpy_transforms import compute_f0, stft, griffin_lim
 
 # pylint: disable=too-many-public-methods
 class AudioProcessor(object):
@@ -460,9 +460,14 @@ class AudioProcessor(object):
             np.ndarray: Spectrogram.
         """
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
         if self.do_amp_to_db_linear:
             S = self._amp_to_db(np.abs(D))
         else:
@@ -472,9 +477,14 @@ class AudioProcessor(object):
     def melspectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a melspectrogram from a waveform."""
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
         if self.do_amp_to_db_mel:
             S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
         else:
@@ -486,18 +496,16 @@ class AudioProcessor(object):
         S = self.denormalize(spectrogram)
         S = self._db_to_amp(S)
         # Reconstruct phase
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
 
     def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
         """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
         D = self.denormalize(mel_spectrogram)
         S = self._db_to_amp(D)
         S = self._mel_to_linear(S)  # Convert back to linear
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
 
     def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
         """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -515,45 +523,16 @@ class AudioProcessor(object):
         mel = self.normalize(S)
         return mel
 
-    ### STFT and ISTFT ###
-    def _stft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa STFT wrapper.
-
-        Args:
-            y (np.ndarray): Audio signal.
-
-        Returns:
-            np.ndarray: Complex number array.
-        """
-        return librosa.stft(
-            y=y,
-            n_fft=self.fft_size,
+    def _griffin_lim(self, S):
+        return griffin_lim(
+            spec=S,
+            num_iter=self.griffin_lim_iters,
             hop_length=self.hop_length,
             win_length=self.win_length,
+            fft_size=self.fft_size,
             pad_mode=self.stft_pad_mode,
-            window="hann",
-            center=True,
         )
 
-    def _istft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa iSTFT wrapper."""
-        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
-
-    def _griffin_lim(self, S):
-        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-        try:
-            S_complex = np.abs(S).astype(np.complex)
-        except AttributeError:  # np.complex is deprecated since numpy 1.20.0
-            S_complex = np.abs(S).astype(complex)
-        y = self._istft(S_complex * angles)
-        if not np.isfinite(y).all():
-            print(" [!] Waveform is not finite everywhere. Skipping the GL.")
-            return np.array([0.0])
-        for _ in range(self.griffin_lim_iters):
-            angles = np.exp(1j * np.angle(self._stft(y)))
-            y = self._istft(S_complex * angles)
-        return y
-
     def compute_stft_paddings(self, x, pad_sides=1):
         """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding (first and final frames)"""
         assert pad_sides in (1, 2)
diff --git a/tests/vocoder_tests/test_vocoder_losses.py b/tests/vocoder_tests/test_vocoder_losses.py
index 2a35aa2e..95501c2d 100644
--- a/tests/vocoder_tests/test_vocoder_losses.py
+++ b/tests/vocoder_tests/test_vocoder_losses.py
@@ -5,6 +5,7 @@ import torch
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.config import BaseAudioConfig
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import stft
 from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
 
 TESTS_PATH = get_tests_path()
@@ -21,7 +22,7 @@ def test_torch_stft():
     torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
     # librosa stft
    wav = ap.load_wav(WAV_FILE)
-    M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
+    M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
     # torch stft
     wav = torch.from_numpy(wav[None, :]).float()
     M_torch = torch_stft(wav)