mirror of https://github.com/coqui-ai/TTS.git
Add `do_amp_to_db` option
This commit is contained in:
parent
e94c1f894d
commit
060e746e21
|
@ -36,6 +36,10 @@ class BaseAudioConfig(Coqpit):
|
||||||
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
|
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
|
||||||
do_trim_silence (bool):
|
do_trim_silence (bool):
|
||||||
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
|
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
|
||||||
|
do_amp_to_db_linear (bool, optional):
|
||||||
|
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||||
|
do_amp_to_db_mel (bool, optional):
|
||||||
|
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||||
trim_db (int):
|
trim_db (int):
|
||||||
Silence threshold used for silence trimming. Defaults to 45.
|
Silence threshold used for silence trimming. Defaults to 45.
|
||||||
power (float):
|
power (float):
|
||||||
|
@ -91,6 +95,8 @@ class BaseAudioConfig(Coqpit):
|
||||||
mel_fmin: float = 0.0
|
mel_fmin: float = 0.0
|
||||||
mel_fmax: float = None
|
mel_fmax: float = None
|
||||||
spec_gain: int = 20
|
spec_gain: int = 20
|
||||||
|
do_amp_to_db_linear: bool = True
|
||||||
|
do_amp_to_db_mel: bool = True
|
||||||
# normalization params
|
# normalization params
|
||||||
signal_norm: bool = True
|
signal_norm: bool = True
|
||||||
min_level_db: int = -100
|
min_level_db: int = -100
|
||||||
|
|
|
@ -14,7 +14,10 @@ from TTS.tts.utils.data import StandardScaler
|
||||||
|
|
||||||
|
|
||||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
"""TODO: Merge this with audio.py"""
|
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||||
|
|
||||||
|
TODO: Merge this with audio.py
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -28,6 +31,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
mel_fmax=None,
|
mel_fmax=None,
|
||||||
n_mels=80,
|
n_mels=80,
|
||||||
use_mel=False,
|
use_mel=False,
|
||||||
|
do_amp_to_db=False,
|
||||||
|
spec_gain=1.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_fft = n_fft
|
self.n_fft = n_fft
|
||||||
|
@ -39,6 +44,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
self.mel_fmax = mel_fmax
|
self.mel_fmax = mel_fmax
|
||||||
self.n_mels = n_mels
|
self.n_mels = n_mels
|
||||||
self.use_mel = use_mel
|
self.use_mel = use_mel
|
||||||
|
self.do_amp_to_db = do_amp_to_db
|
||||||
|
self.spec_gain = spec_gain
|
||||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||||
self.mel_basis = None
|
self.mel_basis = None
|
||||||
if use_mel:
|
if use_mel:
|
||||||
|
@ -79,6 +86,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
||||||
if self.use_mel:
|
if self.use_mel:
|
||||||
S = torch.matmul(self.mel_basis.to(x), S)
|
S = torch.matmul(self.mel_basis.to(x), S)
|
||||||
|
if self.do_amp_to_db:
|
||||||
|
S = self._amp_to_db(S, spec_gain=self.spec_gain)
|
||||||
return S
|
return S
|
||||||
|
|
||||||
def _build_mel_basis(self):
|
def _build_mel_basis(self):
|
||||||
|
@ -87,6 +96,12 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
)
|
)
|
||||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||||
|
|
||||||
|
def _amp_to_db(self, x, spec_gain=1.0):
|
||||||
|
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
|
||||||
|
|
||||||
|
def _db_to_amp(self, x, spec_gain=1.0):
|
||||||
|
return torch.exp(x) / spec_gain
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
# pylint: disable=too-many-public-methods
|
||||||
class AudioProcessor(object):
|
class AudioProcessor(object):
|
||||||
|
@ -97,33 +112,93 @@ class AudioProcessor(object):
|
||||||
of the class with the model config. They are not meaningful for all the arguments.
|
of the class with the model config. They are not meaningful for all the arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample_rate (int, optional): target audio sampling rate. Defaults to None.
|
sample_rate (int, optional):
|
||||||
resample (bool, optional): enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
target audio sampling rate. Defaults to None.
|
||||||
num_mels (int, optional): number of melspectrogram dimensions. Defaults to None.
|
|
||||||
log_func (int, optional): log exponent used for converting spectrogram aplitude to DB.
|
resample (bool, optional):
|
||||||
min_level_db (int, optional): minimum db threshold for the computed melspectrograms. Defaults to None.
|
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
||||||
frame_shift_ms (int, optional): milliseconds of frames between STFT columns. Defaults to None.
|
|
||||||
frame_length_ms (int, optional): milliseconds of STFT window length. Defaults to None.
|
num_mels (int, optional):
|
||||||
hop_length (int, optional): number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
|
number of melspectrogram dimensions. Defaults to None.
|
||||||
win_length (int, optional): STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
|
|
||||||
ref_level_db (int, optional): reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
|
log_func (int, optional):
|
||||||
fft_size (int, optional): FFT window size for STFT. Defaults to 1024.
|
log exponent used for converting spectrogram aplitude to DB.
|
||||||
power (int, optional): Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
|
|
||||||
preemphasis (float, optional): Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
|
min_level_db (int, optional):
|
||||||
signal_norm (bool, optional): enable/disable signal normalization. Defaults to None.
|
minimum db threshold for the computed melspectrograms. Defaults to None.
|
||||||
symmetric_norm (bool, optional): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
|
|
||||||
max_norm (float, optional): ```k``` defining the normalization range. Defaults to None.
|
frame_shift_ms (int, optional):
|
||||||
mel_fmin (int, optional): minimum filter frequency for computing melspectrograms. Defaults to None.
|
milliseconds of frames between STFT columns. Defaults to None.
|
||||||
mel_fmax (int, optional): maximum filter frequency for computing melspectrograms.. Defaults to None.
|
|
||||||
spec_gain (int, optional): gain applied when converting amplitude to DB. Defaults to 20.
|
frame_length_ms (int, optional):
|
||||||
stft_pad_mode (str, optional): Padding mode for STFT. Defaults to 'reflect'.
|
milliseconds of STFT window length. Defaults to None.
|
||||||
clip_norm (bool, optional): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
|
||||||
griffin_lim_iters (int, optional): Number of GriffinLim iterations. Defaults to None.
|
hop_length (int, optional):
|
||||||
do_trim_silence (bool, optional): enable/disable silence trimming when loading the audio signal. Defaults to False.
|
number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
|
||||||
trim_db (int, optional): DB threshold used for silence trimming. Defaults to 60.
|
|
||||||
do_sound_norm (bool, optional): enable/disable signal normalization. Defaults to False.
|
win_length (int, optional):
|
||||||
stats_path (str, optional): Path to the computed stats file. Defaults to None.
|
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
|
||||||
verbose (bool, optional): enable/disable logging. Defaults to True.
|
|
||||||
|
ref_level_db (int, optional):
|
||||||
|
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
|
||||||
|
|
||||||
|
fft_size (int, optional):
|
||||||
|
FFT window size for STFT. Defaults to 1024.
|
||||||
|
|
||||||
|
power (int, optional):
|
||||||
|
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
|
||||||
|
|
||||||
|
preemphasis (float, optional):
|
||||||
|
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
|
||||||
|
|
||||||
|
signal_norm (bool, optional):
|
||||||
|
enable/disable signal normalization. Defaults to None.
|
||||||
|
|
||||||
|
symmetric_norm (bool, optional):
|
||||||
|
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
|
||||||
|
|
||||||
|
max_norm (float, optional):
|
||||||
|
```k``` defining the normalization range. Defaults to None.
|
||||||
|
|
||||||
|
mel_fmin (int, optional):
|
||||||
|
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||||
|
|
||||||
|
mel_fmax (int, optional):
|
||||||
|
maximum filter frequency for computing melspectrograms.. Defaults to None.
|
||||||
|
|
||||||
|
spec_gain (int, optional):
|
||||||
|
gain applied when converting amplitude to DB. Defaults to 20.
|
||||||
|
|
||||||
|
stft_pad_mode (str, optional):
|
||||||
|
Padding mode for STFT. Defaults to 'reflect'.
|
||||||
|
|
||||||
|
clip_norm (bool, optional):
|
||||||
|
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
||||||
|
|
||||||
|
griffin_lim_iters (int, optional):
|
||||||
|
Number of GriffinLim iterations. Defaults to None.
|
||||||
|
|
||||||
|
do_trim_silence (bool, optional):
|
||||||
|
enable/disable silence trimming when loading the audio signal. Defaults to False.
|
||||||
|
|
||||||
|
trim_db (int, optional):
|
||||||
|
DB threshold used for silence trimming. Defaults to 60.
|
||||||
|
|
||||||
|
do_sound_norm (bool, optional):
|
||||||
|
enable/disable signal normalization. Defaults to False.
|
||||||
|
|
||||||
|
do_amp_to_db_linear (bool, optional):
|
||||||
|
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||||
|
|
||||||
|
do_amp_to_db_mel (bool, optional):
|
||||||
|
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||||
|
|
||||||
|
stats_path (str, optional):
|
||||||
|
Path to the computed stats file. Defaults to None.
|
||||||
|
|
||||||
|
verbose (bool, optional):
|
||||||
|
enable/disable logging. Defaults to True.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -153,6 +228,8 @@ class AudioProcessor(object):
|
||||||
do_trim_silence=False,
|
do_trim_silence=False,
|
||||||
trim_db=60,
|
trim_db=60,
|
||||||
do_sound_norm=False,
|
do_sound_norm=False,
|
||||||
|
do_amp_to_db_linear=True,
|
||||||
|
do_amp_to_db_mel=True,
|
||||||
stats_path=None,
|
stats_path=None,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
**_,
|
**_,
|
||||||
|
@ -181,6 +258,8 @@ class AudioProcessor(object):
|
||||||
self.do_trim_silence = do_trim_silence
|
self.do_trim_silence = do_trim_silence
|
||||||
self.trim_db = trim_db
|
self.trim_db = trim_db
|
||||||
self.do_sound_norm = do_sound_norm
|
self.do_sound_norm = do_sound_norm
|
||||||
|
self.do_amp_to_db_linear = do_amp_to_db_linear
|
||||||
|
self.do_amp_to_db_mel = do_amp_to_db_mel
|
||||||
self.stats_path = stats_path
|
self.stats_path = stats_path
|
||||||
# setup exp_func for db to amp conversion
|
# setup exp_func for db to amp conversion
|
||||||
if log_func == "np.log":
|
if log_func == "np.log":
|
||||||
|
@ -381,7 +460,6 @@ class AudioProcessor(object):
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: Decibels spectrogram.
|
np.ndarray: Decibels spectrogram.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
|
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
|
@ -448,7 +526,10 @@ class AudioProcessor(object):
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
D = self._stft(self.apply_preemphasis(y))
|
||||||
else:
|
else:
|
||||||
D = self._stft(y)
|
D = self._stft(y)
|
||||||
|
if self.do_amp_to_db_linear:
|
||||||
S = self._amp_to_db(np.abs(D))
|
S = self._amp_to_db(np.abs(D))
|
||||||
|
else:
|
||||||
|
S = np.abs(D)
|
||||||
return self.normalize(S).astype(np.float32)
|
return self.normalize(S).astype(np.float32)
|
||||||
|
|
||||||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||||
|
@ -457,7 +538,10 @@ class AudioProcessor(object):
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
D = self._stft(self.apply_preemphasis(y))
|
||||||
else:
|
else:
|
||||||
D = self._stft(y)
|
D = self._stft(y)
|
||||||
|
if self.do_amp_to_db_mel:
|
||||||
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
||||||
|
else:
|
||||||
|
S = self._linear_to_mel(np.abs(D))
|
||||||
return self.normalize(S).astype(np.float32)
|
return self.normalize(S).astype(np.float32)
|
||||||
|
|
||||||
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
|
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
|
||||||
|
|
Loading…
Reference in New Issue