refactor(audio): improve type hints, address lint issues

This commit is contained in:
Enno Hermann 2024-11-17 16:40:43 +01:00
parent 48f5be2ccb
commit 5784f6705a
2 changed files with 89 additions and 101 deletions

View File

@ -1,6 +1,6 @@
import logging
from io import BytesIO
from typing import Tuple
from typing import Optional
import librosa
import numpy as np
@ -16,11 +16,11 @@ logger = logging.getLogger(__name__)
def build_mel_basis(
*,
sample_rate: int = None,
fft_size: int = None,
num_mels: int = None,
mel_fmax: int = None,
mel_fmin: int = None,
sample_rate: int,
fft_size: int,
num_mels: int,
mel_fmin: int,
mel_fmax: Optional[int] = None,
**kwargs,
) -> np.ndarray:
"""Build melspectrogram basis.
@ -34,9 +34,7 @@ def build_mel_basis(
return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
def millisec_to_length(
*, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
def millisec_to_length(*, frame_length_ms: float, frame_shift_ms: float, sample_rate: int, **kwargs) -> tuple[int, int]:
"""Compute hop and window length from milliseconds.
Returns:
@ -61,7 +59,7 @@ def _exp(x, base):
return np.exp(x)
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert amplitude values to decibels.
Args:
@ -77,7 +75,7 @@ def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs
# pylint: disable=no-self-use
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
def db_to_amp(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray:
"""Convert decibels spectrogram to amplitude spectrogram.
Args:
@ -104,18 +102,20 @@ def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
np.ndarray: Decorrelated audio signal.
"""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
msg = " [!] Preemphasis is set 0.0."
raise RuntimeError(msg)
return scipy.signal.lfilter([1, -coef], [1], x)
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
def deemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Reverse pre-emphasis."""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
msg = " [!] Preemphasis is set 0.0."
raise ValueError(msg)
return scipy.signal.lfilter([1], [1, -coef], x)
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray, **kwargs) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
Args:
@ -130,14 +130,14 @@ def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) ->
return np.dot(mel_basis, spec)
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
def mel_to_spec(*, mel: np.ndarray, mel_basis: np.ndarray, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to full scale spectrogram."""
assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
inv_mel_basis = np.linalg.pinv(mel_basis)
return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
def wav_to_spec(*, wav: np.ndarray, **kwargs) -> np.ndarray:
"""Compute a spectrogram from a waveform.
Args:
@ -151,7 +151,7 @@ def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
return S.astype(np.float32)
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
def wav_to_mel(*, wav: np.ndarray, mel_basis: np.ndarray, **kwargs) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
D = stft(y=wav, **kwargs)
S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
@ -164,20 +164,20 @@ def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray
return griffin_lim(spec=S**power, **kwargs)
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
def mel_to_wav(*, mel: np.ndarray, mel_basis: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
S = mel.copy()
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
S = mel_to_spec(mel=S, mel_basis=mel_basis) # Convert back to linear
return griffin_lim(spec=S**power, **kwargs)
### STFT and ISTFT ###
def stft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
y: np.ndarray,
fft_size: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
pad_mode: str = "reflect",
window: str = "hann",
center: bool = True,
@ -203,9 +203,9 @@ def stft(
def istft(
*,
y: np.ndarray = None,
hop_length: int = None,
win_length: int = None,
y: np.ndarray,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: str = "hann",
center: bool = True,
**kwargs,
@ -220,7 +220,7 @@ def istft(
return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
def griffin_lim(*, spec: np.ndarray, num_iter=60, **kwargs) -> np.ndarray:
angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs)
@ -233,11 +233,11 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray
return y
def compute_stft_paddings(
*, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""
def compute_stft_paddings(*, x: np.ndarray, hop_length: int, pad_two_sides: bool = False, **kwargs) -> tuple[int, int]:
"""Compute paddings used by Librosa's STFT.
Compute right padding (final frame) or both sides padding (first and final frames).
"""
pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
if not pad_two_sides:
return 0, pad
@ -246,12 +246,12 @@ def compute_stft_paddings(
def compute_f0(
*,
x: np.ndarray = None,
pitch_fmax: float = None,
pitch_fmin: float = None,
hop_length: int = None,
win_length: int = None,
sample_rate: int = None,
x: np.ndarray,
pitch_fmax: Optional[float] = None,
pitch_fmin: Optional[float] = None,
hop_length: int,
win_length: int,
sample_rate: int,
stft_pad_mode: str = "reflect",
center: bool = True,
**kwargs,
@ -323,19 +323,18 @@ def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
"""
x = stft(y=y, **kwargs)
mag, _ = magphase(x)
energy = np.sqrt(np.sum(mag**2, axis=0))
return energy
return np.sqrt(np.sum(mag**2, axis=0))
### Audio Processing ###
def find_endpoint(
*,
wav: np.ndarray = None,
wav: np.ndarray,
trim_db: float = -40,
sample_rate: int = None,
min_silence_sec=0.8,
gain: float = None,
base: int = None,
sample_rate: int,
min_silence_sec: float = 0.8,
gain: float = 1,
base: float = 10,
**kwargs,
) -> int:
"""Find the last point without silence at the end of a audio signal.
@ -344,8 +343,8 @@ def find_endpoint(
wav (np.ndarray): Audio signal.
threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
gain (float, optional): Gain factor to be used to convert trim_db to trim_amp. Defaults to 1.
base (float, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
Returns:
int: Last point without silence.
@ -361,20 +360,20 @@ def find_endpoint(
def trim_silence(
*,
wav: np.ndarray = None,
sample_rate: int = None,
trim_db: float = None,
win_length: int = None,
hop_length: int = None,
wav: np.ndarray,
sample_rate: int,
trim_db: float = 60,
win_length: int,
hop_length: int,
**kwargs,
) -> np.ndarray:
"""Trim silent parts with a threshold and 0.01 sec margin"""
"""Trim silent parts with a threshold and 0.01 sec margin."""
margin = int(sample_rate * 0.01)
wav = wav[margin:-margin]
return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
def volume_norm(*, x: np.ndarray, coef: float = 0.95, **kwargs) -> np.ndarray:
"""Normalize the volume of an audio signal.
Args:
@ -387,7 +386,7 @@ def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.nda
return x / abs(x).max() * coef
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
def rms_norm(*, wav: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
@ -404,11 +403,10 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n
np.ndarray: RMS normalized waveform.
"""
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
wav = rms_norm(wav=x, db_level=db_level)
return wav
return rms_norm(wav=x, db_level=db_level)
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool = False, **kwargs) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
@ -433,13 +431,13 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False,
return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **kwargs) -> None:
"""Save float waveform to a file using Scipy.
Args:
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
path (str): Path to a output file.
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
sr (int): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
"""
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
@ -465,8 +463,7 @@ def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
"""Recovers waveform from quantized values."""
mu = 2**mulaw_qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
return np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:

View File

@ -1,6 +1,6 @@
import logging
from io import BytesIO
from typing import Dict, Tuple
from typing import Optional
import librosa
import numpy as np
@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
class AudioProcessor:
"""Audio Processor for TTS.
Note:
@ -172,7 +172,7 @@ class AudioProcessor(object):
db_level=None,
stats_path=None,
**_,
):
) -> None:
# setup class attributed
self.sample_rate = sample_rate
self.resample = resample
@ -210,7 +210,8 @@ class AudioProcessor(object):
elif log_func == "np.log10":
self.base = 10
else:
raise ValueError(" [!] unknown `log_func` value.")
msg = " [!] unknown `log_func` value."
raise ValueError(msg)
# setup stft parameters
if hop_length is None:
# compute stft parameters from given time values
@ -254,7 +255,7 @@ class AudioProcessor(object):
### normalization ###
def normalize(self, S: np.ndarray) -> np.ndarray:
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`.
Args:
S (np.ndarray): Spectrogram to normalize.
@ -272,10 +273,10 @@ class AudioProcessor(object):
if hasattr(self, "mel_scaler"):
if S.shape[0] == self.num_mels:
return self.mel_scaler.transform(S.T).T
elif S.shape[0] == self.fft_size / 2:
if S.shape[0] == self.fft_size / 2:
return self.linear_scaler.transform(S.T).T
else:
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
msg = " [!] Mean-Var stats does not match the given feature dimensions."
raise RuntimeError(msg)
# range normalization
S -= self.ref_level_db # discard certain range of DB assuming it is air noise
S_norm = (S - self.min_level_db) / (-self.min_level_db)
@ -286,13 +287,11 @@ class AudioProcessor(object):
S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
)
return S_norm
else:
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
return S_norm
else:
return S
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
return S_norm
return S
def denormalize(self, S: np.ndarray) -> np.ndarray:
"""Denormalize spectrogram values.
@ -313,10 +312,10 @@ class AudioProcessor(object):
if hasattr(self, "mel_scaler"):
if S_denorm.shape[0] == self.num_mels:
return self.mel_scaler.inverse_transform(S_denorm.T).T
elif S_denorm.shape[0] == self.fft_size / 2:
if S_denorm.shape[0] == self.fft_size / 2:
return self.linear_scaler.inverse_transform(S_denorm.T).T
else:
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
msg = " [!] Mean-Var stats does not match the given feature dimensions."
raise RuntimeError(msg)
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(
@ -324,16 +323,14 @@ class AudioProcessor(object):
)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm + self.ref_level_db
else:
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
return S_denorm + self.ref_level_db
else:
return S_denorm
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
return S_denorm + self.ref_level_db
return S_denorm
### Mean-STD scaling ###
def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
def load_stats(self, stats_path: str) -> tuple[np.array, np.array, np.array, np.array, dict]:
"""Loading mean and variance statistics from a `npy` file.
Args:
@ -351,7 +348,7 @@ class AudioProcessor(object):
stats_config = stats["audio_config"]
# check all audio parameters used for computing stats
skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
for key in stats_config.keys():
for key in stats_config:
if key in skip_parameters:
continue
if key not in ["sample_rate", "trim_db"]:
@ -415,10 +412,7 @@ class AudioProcessor(object):
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_linear:
S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
else:
S = np.abs(D)
S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) if self.do_amp_to_db_linear else np.abs(D)
return self.normalize(S).astype(np.float32)
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
@ -467,8 +461,7 @@ class AudioProcessor(object):
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
mel = self.normalize(S)
return mel
return self.normalize(S)
def _griffin_lim(self, S):
return griffin_lim(
@ -502,7 +495,7 @@ class AudioProcessor(object):
if len(x) % self.hop_length == 0:
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
f0 = compute_f0(
return compute_f0(
x=x,
pitch_fmax=self.pitch_fmax,
pitch_fmin=self.pitch_fmin,
@ -513,8 +506,6 @@ class AudioProcessor(object):
center=True,
)
return f0
### Audio Processing ###
def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
"""Find the last point without silence at the end of a audio signal.
@ -537,7 +528,7 @@ class AudioProcessor(object):
)
def trim_silence(self, wav):
"""Trim silent parts with a threshold and 0.01 sec margin"""
"""Trim silent parts with a threshold and 0.01 sec margin."""
return trim_silence(
wav=wav,
sample_rate=self.sample_rate,
@ -572,7 +563,7 @@ class AudioProcessor(object):
return rms_volume_norm(x=x, db_level=db_level)
### save and load ###
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.