mirror of https://github.com/coqui-ai/TTS.git
refactor(audio.processor): remove duplicate stft+griffin_lim
This commit is contained in:
parent
8fa4de1c8c
commit
d75879802a
|
@ -8,7 +8,7 @@ import scipy.signal
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from TTS.tts.utils.helpers import StandardScaler
|
from TTS.tts.utils.helpers import StandardScaler
|
||||||
from TTS.utils.audio.numpy_transforms import compute_f0
|
from TTS.utils.audio.numpy_transforms import compute_f0, stft, griffin_lim
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
# pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
|
@ -460,9 +460,14 @@ class AudioProcessor(object):
|
||||||
np.ndarray: Spectrogram.
|
np.ndarray: Spectrogram.
|
||||||
"""
|
"""
|
||||||
if self.preemphasis != 0:
|
if self.preemphasis != 0:
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
y = self.apply_preemphasis(y)
|
||||||
else:
|
D = stft(
|
||||||
D = self._stft(y)
|
y=y,
|
||||||
|
fft_size=self.fft_size,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
pad_mode=self.stft_pad_mode,
|
||||||
|
)
|
||||||
if self.do_amp_to_db_linear:
|
if self.do_amp_to_db_linear:
|
||||||
S = self._amp_to_db(np.abs(D))
|
S = self._amp_to_db(np.abs(D))
|
||||||
else:
|
else:
|
||||||
|
@ -472,9 +477,14 @@ class AudioProcessor(object):
|
||||||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||||
"""Compute a melspectrogram from a waveform."""
|
"""Compute a melspectrogram from a waveform."""
|
||||||
if self.preemphasis != 0:
|
if self.preemphasis != 0:
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
y = self.apply_preemphasis(y)
|
||||||
else:
|
D = stft(
|
||||||
D = self._stft(y)
|
y=y,
|
||||||
|
fft_size=self.fft_size,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
pad_mode=self.stft_pad_mode,
|
||||||
|
)
|
||||||
if self.do_amp_to_db_mel:
|
if self.do_amp_to_db_mel:
|
||||||
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
||||||
else:
|
else:
|
||||||
|
@ -486,18 +496,16 @@ class AudioProcessor(object):
|
||||||
S = self.denormalize(spectrogram)
|
S = self.denormalize(spectrogram)
|
||||||
S = self._db_to_amp(S)
|
S = self._db_to_amp(S)
|
||||||
# Reconstruct phase
|
# Reconstruct phase
|
||||||
if self.preemphasis != 0:
|
W = self._griffin_lim(S**self.power)
|
||||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||||
return self._griffin_lim(S**self.power)
|
|
||||||
|
|
||||||
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
|
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
|
||||||
"""Convert a melspectrogram to a waveform using Griffin-Lim vocoder."""
|
"""Convert a melspectrogram to a waveform using Griffin-Lim vocoder."""
|
||||||
D = self.denormalize(mel_spectrogram)
|
D = self.denormalize(mel_spectrogram)
|
||||||
S = self._db_to_amp(D)
|
S = self._db_to_amp(D)
|
||||||
S = self._mel_to_linear(S) # Convert back to linear
|
S = self._mel_to_linear(S) # Convert back to linear
|
||||||
if self.preemphasis != 0:
|
W = self._griffin_lim(S**self.power)
|
||||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||||
return self._griffin_lim(S**self.power)
|
|
||||||
|
|
||||||
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
|
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
|
||||||
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
||||||
|
@ -515,45 +523,16 @@ class AudioProcessor(object):
|
||||||
mel = self.normalize(S)
|
mel = self.normalize(S)
|
||||||
return mel
|
return mel
|
||||||
|
|
||||||
### STFT and ISTFT ###
|
def _griffin_lim(self, S):
|
||||||
def _stft(self, y: np.ndarray) -> np.ndarray:
|
return griffin_lim(
|
||||||
"""Librosa STFT wrapper.
|
spec=S,
|
||||||
|
num_iter=self.griffin_lim_iters,
|
||||||
Args:
|
|
||||||
y (np.ndarray): Audio signal.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Complex number array.
|
|
||||||
"""
|
|
||||||
return librosa.stft(
|
|
||||||
y=y,
|
|
||||||
n_fft=self.fft_size,
|
|
||||||
hop_length=self.hop_length,
|
hop_length=self.hop_length,
|
||||||
win_length=self.win_length,
|
win_length=self.win_length,
|
||||||
|
fft_size=self.fft_size,
|
||||||
pad_mode=self.stft_pad_mode,
|
pad_mode=self.stft_pad_mode,
|
||||||
window="hann",
|
|
||||||
center=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _istft(self, y: np.ndarray) -> np.ndarray:
|
|
||||||
"""Librosa iSTFT wrapper."""
|
|
||||||
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
|
|
||||||
|
|
||||||
def _griffin_lim(self, S):
|
|
||||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
|
||||||
try:
|
|
||||||
S_complex = np.abs(S).astype(np.complex)
|
|
||||||
except AttributeError: # np.complex is deprecated since numpy 1.20.0
|
|
||||||
S_complex = np.abs(S).astype(complex)
|
|
||||||
y = self._istft(S_complex * angles)
|
|
||||||
if not np.isfinite(y).all():
|
|
||||||
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
|
|
||||||
return np.array([0.0])
|
|
||||||
for _ in range(self.griffin_lim_iters):
|
|
||||||
angles = np.exp(1j * np.angle(self._stft(y)))
|
|
||||||
y = self._istft(S_complex * angles)
|
|
||||||
return y
|
|
||||||
|
|
||||||
def compute_stft_paddings(self, x, pad_sides=1):
|
def compute_stft_paddings(self, x, pad_sides=1):
|
||||||
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
|
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
|
||||||
(first and final frames)"""
|
(first and final frames)"""
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
||||||
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||||
from TTS.config import BaseAudioConfig
|
from TTS.config import BaseAudioConfig
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.audio.numpy_transforms import stft
|
||||||
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
||||||
|
|
||||||
TESTS_PATH = get_tests_path()
|
TESTS_PATH = get_tests_path()
|
||||||
|
@ -21,7 +22,7 @@ def test_torch_stft():
|
||||||
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
||||||
# librosa stft
|
# librosa stft
|
||||||
wav = ap.load_wav(WAV_FILE)
|
wav = ap.load_wav(WAV_FILE)
|
||||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
|
||||||
# torch stft
|
# torch stft
|
||||||
wav = torch.from_numpy(wav[None, :]).float()
|
wav = torch.from_numpy(wav[None, :]).float()
|
||||||
M_torch = torch_stft(wav)
|
M_torch = torch_stft(wav)
|
||||||
|
|
Loading…
Reference in New Issue