Add `do_amp_to_db` option

This commit is contained in:
Eren Gölge 2021-08-09 07:56:11 +00:00
parent e94c1f894d
commit 060e746e21
2 changed files with 121 additions and 31 deletions

View File

@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
do_trim_silence (bool):
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
power (float):
@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
mel_fmin: float = 0.0
mel_fmax: float = None
spec_gain: int = 20
do_amp_to_db_linear: bool = True
do_amp_to_db_mel: bool = True
# normalization params
signal_norm: bool = True
min_level_db: int = -100

View File

@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
"""Some of the audio processing functions using Torch for faster batch processing.
TODO: Merge this with audio.py
"""
def __init__(
self,
@ -28,6 +31,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
):
super().__init__()
self.n_fft = n_fft
@ -39,6 +44,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
@ -79,6 +86,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
@ -87,6 +96,12 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
)
self.mel_basis = torch.from_numpy(mel_basis).float()
def _amp_to_db(self, x, spec_gain=1.0):
    """Map amplitude values to the log domain.

    Args:
        x (torch.Tensor): amplitude values. Entries below 1e-5 are raised to 1e-5 first.
        spec_gain (float, optional): factor the floored amplitudes are multiplied by
            before the logarithm is taken. Defaults to 1.0.

    Returns:
        torch.Tensor: natural logarithm of the floored, gain-scaled amplitudes.
    """
    # Floor tiny values so the logarithm stays finite for a positive gain.
    floored = torch.clamp(x, min=1e-5)
    scaled = floored * spec_gain
    return torch.log(scaled)
def _db_to_amp(self, x, spec_gain=1.0):
    """Map log-domain values back to amplitudes.

    Args:
        x (torch.Tensor): log-domain values.
        spec_gain (float, optional): factor the exponentiated values are divided by.
            Defaults to 1.0.

    Returns:
        torch.Tensor: ``exp(x)`` divided by ``spec_gain``.
    """
    amplitude = torch.exp(x)
    return amplitude / spec_gain
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
@ -97,33 +112,93 @@ class AudioProcessor(object):
of the class with the model config. They are not meaningful for all the arguments.
Args:
sample_rate (int, optional): target audio sampling rate. Defaults to None.
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
log_func (int, optional): log exponent used for converting spectrogram amplitude to DB.
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional): FFT window size for STFT. Defaults to 1024.
power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional): enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional): ```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional): maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional): enable/disable clipping the out of range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False.
stats_path (str, optional): Path to the computed stats file. Defaults to None.
verbose (bool, optional): enable/disable logging. Defaults to True.
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
resample (bool, optional):
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
log_func (str, optional):
log exponent used for converting spectrogram amplitude to DB.
min_level_db (int, optional):
minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional):
milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional):
milliseconds of STFT window length. Defaults to None.
hop_length (int, optional):
number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
win_length (int, optional):
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional):
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional):
FFT window size for STFT. Defaults to 1024.
power (int, optional):
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional):
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional):
enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional):
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional):
```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional):
Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional):
enable/disable clipping the out of range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional):
Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional):
enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional):
DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional):
enable/disable signal normalization. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
"""
def __init__(
@ -153,6 +228,8 @@ class AudioProcessor(object):
do_trim_silence=False,
trim_db=60,
do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
stats_path=None,
verbose=True,
**_,
@ -181,6 +258,8 @@ class AudioProcessor(object):
self.do_trim_silence = do_trim_silence
self.trim_db = trim_db
self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
self.stats_path = stats_path
# setup exp_func for db to amp conversion
if log_func == "np.log":
@ -381,7 +460,6 @@ class AudioProcessor(object):
Returns:
np.ndarray: Decibels spectrogram.
"""
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
# pylint: disable=no-self-use
@ -448,7 +526,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D))
else:
S = np.abs(D)
return self.normalize(S).astype(np.float32)
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
@ -457,7 +538,10 @@ class AudioProcessor(object):
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
else:
S = self._linear_to_mel(np.abs(D))
return self.normalize(S).astype(np.float32)
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: