mirror of https://github.com/coqui-ai/TTS.git
Add numpy and torch transforms
This commit is contained in:
parent c3fb49bf76
commit 6a53b77a95

@@ -0,0 +1,452 @@
from typing import Callable, Tuple

import librosa
import numpy as np
import pyworld as pw
import scipy.io.wavfile
import scipy.signal
import soundfile as sf

# from TTS.tts.utils.helpers import StandardScaler


def build_mel_basis(
    *,
    sample_rate: int = None,
    fft_size: int = None,
    num_mels: int = None,
    mel_fmax: int = None,
    mel_fmin: int = None,
    **kwargs,
) -> np.ndarray:
    """Build melspectrogram basis.

    Returns:
        np.ndarray: melspectrogram basis.
    """
    if mel_fmax is not None:
        assert mel_fmax <= sample_rate // 2
        assert mel_fmax - mel_fmin > 0
    return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
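
# Illustrative example (values chosen for this sketch, not from the commit): for
# 22.05 kHz audio, a 1024-point FFT and 80 mel bands, the returned basis has shape
# (num_mels, fft_size // 2 + 1):
# >>> basis = build_mel_basis(sample_rate=22050, fft_size=1024, num_mels=80, mel_fmin=0, mel_fmax=8000)
# >>> basis.shape
# (80, 513)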


def millisec_to_length(
    *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
    """Compute hop and window length from milliseconds.

    Returns:
        Tuple[int, int]: hop length and window length for STFT.
    """
    factor = frame_length_ms / frame_shift_ms
    assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
    hop_length = int(frame_shift_ms / 1000.0 * sample_rate)
    win_length = int(hop_length * factor)
    return hop_length, win_length
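
# Illustrative example (values are assumptions): a 50 ms window with a 10 ms shift
# at 16 kHz gives a hop of 160 samples and a window of 800 samples:
# >>> millisec_to_length(frame_length_ms=50, frame_shift_ms=10, sample_rate=16000)
# (160, 800)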


def _log(x, base):
    if base == 10:
        return np.log10(x)
    return np.log(x)


def _exp(x, base):
    if base == 10:
        return np.power(10, x)
    return np.exp(x)


def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Convert amplitude values to decibels.

    Args:
        x (np.ndarray): Amplitude spectrogram.
        gain (float): Gain factor. Defaults to 1.
        base (int): Logarithm base. Defaults to 10.

    Returns:
        np.ndarray: Decibels spectrogram.
    """
    return gain * _log(np.maximum(1e-5, x), base)


# pylint: disable=no-self-use
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Convert decibels spectrogram to amplitude spectrogram.

    Args:
        x (np.ndarray): Decibels spectrogram.
        gain (float): Gain factor. Defaults to 1.
        base (int): Logarithm base. Defaults to 10.

    Returns:
        np.ndarray: Amplitude spectrogram.
    """
    return _exp(x / gain, base)
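
# Illustrative round trip (parameter values are assumptions): with gain=20 and
# base=10 these two functions follow the usual 20 * log10 dB convention:
# >>> amp_to_db(x=np.array([1.0]), gain=20, base=10)
# array([0.])
# >>> db_to_amp(x=np.array([0.0]), gain=20, base=10)
# array([1.])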


def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.

    Args:
        x (np.ndarray): Audio signal.

    Raises:
        RuntimeError: Preemphasis coefficient is set to 0.

    Returns:
        np.ndarray: Decorrelated audio signal.
    """
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis coefficient is set to 0.0.")
    return scipy.signal.lfilter([1, -coef], [1], x)


def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Reverse pre-emphasis."""
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis coefficient is set to 0.0.")
    return scipy.signal.lfilter([1], [1, -coef], x)
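
# Illustrative check (setup values are assumptions): de-emphasis is the exact
# inverse filter of pre-emphasis, so chaining the two recovers the input up to
# floating-point error:
# >>> sig = np.random.RandomState(0).randn(100)
# >>> np.allclose(deemphasis(x=preemphasis(x=sig, coef=0.97), coef=0.97), sig)
# True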


def spec_to_mel(*, spec: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Project a full scale spectrogram to a melspectrogram.

    Args:
        spec (np.ndarray): Full scale spectrogram.
        mel_basis (np.ndarray): Melspectrogram basis from `build_mel_basis`.

    Returns:
        np.ndarray: Melspectrogram.
    """
    return np.dot(mel_basis, spec)


def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Convert a melspectrogram to full scale spectrogram."""
    inv_mel_basis = np.linalg.pinv(mel_basis)
    return np.maximum(1e-10, np.dot(inv_mel_basis, mel))


def wav_to_spec(*, y: np.ndarray = None, **kwargs) -> np.ndarray:
    """Compute a spectrogram from a waveform.

    Args:
        y (np.ndarray): Waveform.

    Returns:
        np.ndarray: Spectrogram.
    """
    D = stft(y=y, **kwargs)
    S = np.abs(D)
    return S.astype(np.float32)


def wav_to_mel(*, y: np.ndarray = None, **kwargs) -> np.ndarray:
    """Compute a melspectrogram from a waveform."""
    D = stft(y=y, **kwargs)
    S = spec_to_mel(spec=np.abs(D), **kwargs)
    return S.astype(np.float32)
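
# Illustrative end-to-end sketch (all parameter values and the file name are
# assumptions, not from the commit): the keyword-only style lets a single settings
# dict drive the whole chain:
# >>> settings = dict(sample_rate=22050, fft_size=1024, hop_length=256,
# ...                 win_length=1024, num_mels=80, mel_fmin=0, mel_fmax=8000)
# >>> mel_basis = build_mel_basis(**settings)
# >>> wav = load_wav(filename="example.wav", sample_rate=22050)
# >>> mel_db = amp_to_db(x=wav_to_mel(y=wav, mel_basis=mel_basis, **settings))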


def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
    """Convert a spectrogram to a waveform using the Griffin-Lim vocoder."""
    S = spec.copy()
    if denorm_func is not None:
        S = denorm_func(spec=S, **kwargs)
    S = db_to_amp(x=S, **kwargs)
    return griffin_lim(spec=S**power, **kwargs)


def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
    """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder."""
    S = mel.copy()
    if denorm_func is not None:
        S = denorm_func(spec=S, **kwargs)
    S = db_to_amp(x=S, **kwargs)
    S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"])  # Convert back to linear
    return griffin_lim(spec=S**power, **kwargs)


### STFT and ISTFT ###
def stft(
    *,
    y: np.ndarray = None,
    fft_size: int = None,
    hop_length: int = None,
    win_length: int = None,
    pad_mode: str = "reflect",
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Librosa STFT wrapper.

    Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.

    Returns:
        np.ndarray: Complex number array.
    """
    return librosa.stft(
        y=y,
        n_fft=fft_size,
        hop_length=hop_length,
        win_length=win_length,
        pad_mode=pad_mode,
        window=window,
        center=center,
    )


def istft(
    *,
    y: np.ndarray = None,
    fft_size: int = None,
    hop_length: int = None,
    win_length: int = None,
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Librosa iSTFT wrapper.

    Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.

    Returns:
        np.ndarray: Reconstructed time-domain waveform.
    """
    return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)


def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
    angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
    S_complex = np.abs(spec).astype(complex)
    y = istft(y=S_complex * angles, **kwargs)
    if not np.isfinite(y).all():
        print(" [!] Waveform is not finite everywhere. Skipping the GL.")
        return np.array([0.0])
    for _ in range(num_iter):
        angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
        y = istft(y=S_complex * angles, **kwargs)
    return y


def compute_stft_paddings(
    *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
    """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
    (first and final frames)."""
    pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
    if not pad_two_sides:
        return 0, pad
    return pad // 2, pad // 2 + pad % 2


def compute_f0(
    *, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

    Args:
        x (np.ndarray): Waveform.

    Returns:
        np.ndarray: Pitch.

    Examples:
        >>> WAV_FILE = filename = librosa.util.example_audio_file()
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio.processor import AudioProcessor
        >>> conf = BaseAudioConfig(pitch_fmax=8000)
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
        >>> pitch = ap.compute_f0(wav)
    """
    assert pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."

    f0, t = pw.dio(
        x.astype(np.double),
        fs=sample_rate,
        f0_ceil=pitch_fmax,
        frame_period=1000 * hop_length / sample_rate,
    )
    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
    return f0
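
# Illustrative module-level usage (parameter values and file name are assumptions):
# frame_period is derived from the same hop_length used for the spectrogram, so the
# f0 contour lines up frame-by-frame with the melspectrogram:
# >>> wav = load_wav(filename="example.wav", sample_rate=22050)
# >>> f0 = compute_f0(x=wav, pitch_fmax=640.0, hop_length=256, sample_rate=22050)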


### Audio Processing ###
def find_endpoint(
    *,
    wav: np.ndarray = None,
    trim_db: float = None,
    sample_rate: int = None,
    min_silence_sec=0.8,
    gain: float = None,
    base: int = 10,
    **kwargs,
) -> int:
    """Find the last point without silence at the end of an audio signal.

    Args:
        wav (np.ndarray): Audio signal.
        trim_db (float, optional): Silence threshold in decibels.
        min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.

    Returns:
        int: Last point without silence.
    """
    window_length = int(sample_rate * min_silence_sec)
    hop_length = int(window_length / 4)
    threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
    for x in range(hop_length, len(wav) - window_length, hop_length):
        if np.max(wav[x : x + window_length]) < threshold:
            return x + hop_length
    return len(wav)


def trim_silence(
    *,
    wav: np.ndarray = None,
    sample_rate: int = None,
    trim_db: float = None,
    win_length: int = None,
    hop_length: int = None,
    **kwargs,
) -> np.ndarray:
    """Trim silent parts with a threshold and 0.01 sec margin."""
    margin = int(sample_rate * 0.01)
    wav = wav[margin:-margin]
    return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]


def sound_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
    """Normalize the volume of an audio signal.

    Args:
        x (np.ndarray): Raw waveform.
        coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.

    Returns:
        np.ndarray: Volume normalized waveform.
    """
    return x / abs(x).max() * coef


def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
    r = 10 ** (db_level / 20)
    a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
    return wav * a


def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
    """Normalize the volume based on RMS of the signal.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float): Target dB level in RMS. Defaults to -27.0.

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
    wav = rms_norm(wav=x, db_level=db_level)
    return wav
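
# Illustrative check (target level is an assumption): after
# rms_volume_norm(x=wav, db_level=-27.0) the waveform RMS equals
# 10 ** (-27 / 20), i.e. roughly 0.0447.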


def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
    """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.

    Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.

    Args:
        filename (str): Path to the wav file.
        sample_rate (int, optional): Sampling rate for resampling. Defaults to None.
        resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.

    Returns:
        np.ndarray: Loaded waveform.
    """
    if resample:
        # loading with resampling. It is significantly slower.
        x, sr = librosa.load(filename, sr=sample_rate)
    else:
        # SF is faster than librosa for loading files
        x, sr = sf.read(filename)
        assert sample_rate == sr, "%s vs %s" % (sample_rate, sr)
    return x


def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
    """Save float waveform to a file using Scipy.

    Args:
        wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
        path (str): Path to the output file.
        sample_rate (int, optional): Sampling rate used for saving to the file. Defaults to None.
    """
    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
    scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))


def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Quantize a waveform with mu-law companding."""
    mu = 2**mulaw_qc - 1
    signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
    signal = (signal + 1) / 2 * mu + 0.5
    return np.floor(signal)


def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Recovers waveform from quantized values."""
    mu = 2**mulaw_qc - 1
    x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
    return x
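
# Illustrative note (usage is an assumption based on the formulas above):
# mulaw_encode maps a [-1, 1] float waveform to integer codes in
# [0, 2**mulaw_qc - 1], e.g. [-1.0, 0.0, 1.0] -> [0.0, 128.0, 255.0] for
# mulaw_qc=8, while mulaw_decode expects its input rescaled back to [-1, 1]
# before expansion.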


def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
    return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)


def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
    """Quantize a waveform to a given number of bits.

    Args:
        x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
        quantize_bits (int): Number of quantization bits.

    Returns:
        np.ndarray: Quantized waveform.
    """
    return (x + 1.0) * (2**quantize_bits - 1) / 2


def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
    """Dequantize a waveform from the given number of bits."""
    return 2 * x / (2**quantize_bits - 1) - 1
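
# Illustrative round trip (values are assumptions): quantize maps [-1, 1] onto
# [0, 2**quantize_bits - 1] and dequantize inverts it exactly since no rounding
# is applied here, e.g. with quantize_bits=10: 0.0 -> 511.5 -> 0.0.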

@@ -0,0 +1,165 @@
import torch
from torch import nn
import librosa


class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    TODO: Merge this with audio.py

    Args:

        n_fft (int):
            FFT window size for STFT.

        hop_length (int):
            number of samples between successive STFT columns.

        win_length (int, optional):
            STFT window length.

        pad_wav (bool, optional):
            If True, pad the audio with (n_fft - hop_length) / 2 samples on each side. Defaults to False.

        window (str, optional):
            The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window".

        sample_rate (int, optional):
            target audio sampling rate. Defaults to None.

        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to 0.

        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.

        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to 80.

        use_mel (bool, optional):
            If True, compute the melspectrogram instead of the linear spectrogram. Defaults to False.

        do_amp_to_db (bool, optional):
            enable/disable amplitude-to-dB conversion of spectrograms. Defaults to False.

        spec_gain (float, optional):
            gain applied when converting amplitude to dB. Defaults to 1.0.

        power (float, optional):
            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.

        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney.

        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).

            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).

            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch-based STFT.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: :math:`[B, T]` or :math:`[B, 1, T]`
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        return torch.exp(x) / spec_gain
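
# Illustrative usage (all parameter values are assumptions, not from the commit):
# compute log-mel frames for a batch of waveforms in one call.
# >>> stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024,
# ...                  sample_rate=22050, n_mels=80, use_mel=True, do_amp_to_db=True)
# >>> wavs = torch.randn(4, 1, 22050)  # [B, 1, T]
# >>> mels = stft(wavs)                # [B, n_mels, T // hop_length + 1]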