refactor(audio.processor): remove duplicate stft+griffin_lim

This commit is contained in:
Enno Hermann 2023-11-13 21:47:36 +01:00
parent 8fa4de1c8c
commit d75879802a
2 changed files with 28 additions and 48 deletions

View File

@ -8,7 +8,7 @@ import scipy.signal
import soundfile as sf import soundfile as sf
from TTS.tts.utils.helpers import StandardScaler from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import compute_f0 from TTS.utils.audio.numpy_transforms import compute_f0, stft, griffin_lim
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
@ -460,9 +460,14 @@ class AudioProcessor(object):
np.ndarray: Spectrogram. np.ndarray: Spectrogram.
""" """
if self.preemphasis != 0: if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y)) y = self.apply_preemphasis(y)
else: D = stft(
D = self._stft(y) y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_linear: if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D)) S = self._amp_to_db(np.abs(D))
else: else:
@ -472,9 +477,14 @@ class AudioProcessor(object):
def melspectrogram(self, y: np.ndarray) -> np.ndarray: def melspectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a melspectrogram from a waveform.""" """Compute a melspectrogram from a waveform."""
if self.preemphasis != 0: if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y)) y = self.apply_preemphasis(y)
else: D = stft(
D = self._stft(y) y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_mel: if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
else: else:
@ -486,18 +496,16 @@ class AudioProcessor(object):
S = self.denormalize(spectrogram) S = self.denormalize(spectrogram)
S = self._db_to_amp(S) S = self._db_to_amp(S)
# Reconstruct phase # Reconstruct phase
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
D = self.denormalize(mel_spectrogram) D = self.denormalize(mel_spectrogram)
S = self._db_to_amp(D) S = self._db_to_amp(D)
S = self._mel_to_linear(S) # Convert back to linear S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram. """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -515,45 +523,16 @@ class AudioProcessor(object):
mel = self.normalize(S) mel = self.normalize(S)
return mel return mel
### STFT and ISTFT ### def _griffin_lim(self, S):
def _stft(self, y: np.ndarray) -> np.ndarray: return griffin_lim(
"""Librosa STFT wrapper. spec=S,
num_iter=self.griffin_lim_iters,
Args:
y (np.ndarray): Audio signal.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=self.fft_size,
hop_length=self.hop_length, hop_length=self.hop_length,
win_length=self.win_length, win_length=self.win_length,
fft_size=self.fft_size,
pad_mode=self.stft_pad_mode, pad_mode=self.stft_pad_mode,
window="hann",
center=True,
) )
def _istft(self, y: np.ndarray) -> np.ndarray:
"""Librosa iSTFT wrapper."""
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
try:
S_complex = np.abs(S).astype(np.complex)
except AttributeError: # np.complex is deprecated since numpy 1.20.0
S_complex = np.abs(S).astype(complex)
y = self._istft(S_complex * angles)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)
return y
def compute_stft_paddings(self, x, pad_sides=1): def compute_stft_paddings(self, x, pad_sides=1):
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)""" (first and final frames)"""

View File

@ -5,6 +5,7 @@ import torch
from tests import get_tests_input_path, get_tests_output_path, get_tests_path from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import stft
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
TESTS_PATH = get_tests_path() TESTS_PATH = get_tests_path()
@ -21,7 +22,7 @@ def test_torch_stft():
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
# librosa stft # librosa stft
wav = ap.load_wav(WAV_FILE) wav = ap.load_wav(WAV_FILE)
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
# torch stft # torch stft
wav = torch.from_numpy(wav[None, :]).float() wav = torch.from_numpy(wav[None, :]).float()
M_torch = torch_stft(wav) M_torch = torch_stft(wav)