Introduce numpy and torch transforms (#1705)

* Refactor audio processing functions

* Add tests for numpy transforms

* Fix imports

* Fix imports2
This commit is contained in:
Eren Gölge 2022-08-08 11:57:50 +02:00 committed by GitHub
parent 7fd9b89ebf
commit d46fbc240c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 702 additions and 173 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
from torch import nn from torch import nn
# from TTS.utils.audio import TorchSTFT # from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.encoder.models.base_encoder import BaseEncoder from TTS.encoder.models.base_encoder import BaseEncoder

View File

@ -8,7 +8,7 @@ from torch.nn import functional
from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
# pylint: disable=abstract-method # pylint: disable=abstract-method

View File

@ -0,0 +1 @@
from TTS.utils.audio.processor import AudioProcessor

View File

@ -0,0 +1,425 @@
from typing import Tuple
import librosa
import numpy as np
import pyworld as pw
import scipy
import soundfile as sf
# For using kwargs
# pylint: disable=unused-argument
def build_mel_basis(
*,
sample_rate: int = None,
fft_size: int = None,
num_mels: int = None,
mel_fmax: int = None,
mel_fmin: int = None,
**kwargs,
) -> np.ndarray:
"""Build melspectrogram basis.
Returns:
np.ndarray: melspectrogram basis.
"""
if mel_fmax is not None:
assert mel_fmax <= sample_rate // 2
assert mel_fmax - mel_fmin > 0
return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
def millisec_to_length(
*, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
"""Compute hop and window length from milliseconds.
Returns:
Tuple[int, int]: hop length and window length for STFT.
"""
factor = frame_length_ms / frame_shift_ms
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
win_length = int(frame_length_ms / 1000.0 * sample_rate)
hop_length = int(win_length / float(factor))
return win_length, hop_length
def _log(x, base):
if base == 10:
return np.log10(x)
return np.log(x)
def _exp(x, base):
if base == 10:
return np.power(10, x)
return np.exp(x)
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert amplitude values to decibels.
Args:
x (np.ndarray): Amplitude spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Decibels spectrogram.
"""
assert (x < 0).sum() == 0, " [!] Input values must be non-negative."
return gain * _log(np.maximum(1e-8, x), base)
# pylint: disable=no-self-use
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert decibels spectrogram to amplitude spectrogram.
Args:
x (np.ndarray): Decibels spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Amplitude spectrogram.
"""
return _exp(x / gain, base)
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
Args:
x (np.ndarray): Audio signal.
Raises:
RuntimeError: Preemphasis coeff is set to 0.
Returns:
np.ndarray: Decorrelated audio signal.
"""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1, -coef], [1], x)
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Reverse pre-emphasis."""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1], [1, -coef], x)
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
Args:
spec (np.ndarray): Normalized full scale linear spectrogram.
Shapes:
- spec: :math:`[C, T]`
Returns:
np.ndarray: Normalized melspectrogram.
"""
return np.dot(mel_basis, spec)
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to full scale spectrogram."""
assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
inv_mel_basis = np.linalg.pinv(mel_basis)
return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
"""Compute a spectrogram from a waveform.
Args:
wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
"""
D = stft(y=wav, **kwargs)
S = np.abs(D)
return S.astype(np.float32)
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
D = stft(y=wav, **kwargs)
S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
return S.astype(np.float32)
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = spec.copy()
return griffin_lim(spec=S**power, **kwargs)
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
S = mel.copy()
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
return griffin_lim(spec=S**power, **kwargs)
### STFT and ISTFT ###
def stft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
pad_mode: str = "reflect",
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa STFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=fft_size,
hop_length=hop_length,
win_length=win_length,
pad_mode=pad_mode,
window=window,
center=center,
)
def istft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa iSTFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
S_complex = np.abs(spec).astype(np.complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
y = istft(y=S_complex * angles, **kwargs)
return y
def compute_stft_paddings(
*, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""
pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
if not pad_two_sides:
return 0, pad
return pad // 2, pad // 2 + pad % 2
def compute_f0(
*, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
>>> pitch = ap.compute_f0(wav)
"""
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
f0, t = pw.dio(
x.astype(np.double),
fs=sample_rate,
f0_ceil=pitch_fmax,
frame_period=1000 * hop_length / sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
return f0
### Audio Processing ###
def find_endpoint(
*,
wav: np.ndarray = None,
trim_db: float = -40,
sample_rate: int = None,
min_silence_sec=0.8,
gain: float = None,
base: int = None,
**kwargs,
) -> int:
"""Find the last point without silence at the end of a audio signal.
Args:
wav (np.ndarray): Audio signal.
threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
Returns:
int: Last point without silence.
"""
window_length = int(sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x : x + window_length]) < threshold:
return x + hop_length
return len(wav)
def trim_silence(
*,
wav: np.ndarray = None,
sample_rate: int = None,
trim_db: float = None,
win_length: int = None,
hop_length: int = None,
**kwargs,
) -> np.ndarray:
"""Trim silent parts with a threshold and 0.01 sec margin"""
margin = int(sample_rate * 0.01)
wav = wav[margin:-margin]
return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
"""Normalize the volume of an audio signal.
Args:
x (np.ndarray): Raw waveform.
coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.
Returns:
np.ndarray: Volume normalized waveform.
"""
return x / abs(x).max() * coef
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
"""Normalize the volume based on RMS of the signal.
Args:
x (np.ndarray): Raw waveform.
db_level (float): Target dB level in RMS. Defaults to -27.0.
Returns:
np.ndarray: RMS normalized waveform.
"""
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
wav = rms_norm(wav=x, db_level=db_level)
return wav
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
Returns:
np.ndarray: Loaded waveform.
"""
if resample:
# loading with resampling. It is significantly slower.
x, _ = librosa.load(filename, sr=sample_rate)
else:
# SF is faster than librosa for loading files
x, _ = sf.read(filename)
return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
"""Save float waveform to a file using Scipy.
Args:
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
path (str): Path to a output file.
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
"""
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
mu = 2**mulaw_qc - 1
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
"""Recovers waveform from quantized values."""
mu = 2**mulaw_qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
quantize_bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**quantize_bits - 1) / 2
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**quantize_bits - 1) - 1

View File

@ -6,179 +6,14 @@ import pyworld as pw
import scipy.io.wavfile import scipy.io.wavfile
import scipy.signal import scipy.signal
import soundfile as sf import soundfile as sf
import torch
from torch import nn
from TTS.tts.utils.helpers import StandardScaler from TTS.tts.utils.helpers import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True pad the audio with (n_fft - hop_length) / 2). Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute the melspectrograms otherwise. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class AudioProcessor(object):
"""Audio Processor for TTS used by all the data pipelines.
TODO: Make this a dataclass to replace `BaseAudioConfig`.
class AudioProcessor(object):
"""Audio Processor for TTS.
Note: Note:
All the class arguments are set to default values to enable a flexible initialization All the class arguments are set to default values to enable a flexible initialization

View File

@ -0,0 +1,163 @@
import librosa
import torch
from torch import nn
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True pad the audio with (n_fft - hop_length) / 2). Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute the melspectrograms otherwise. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain

View File

@ -4,7 +4,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from TTS.utils.audio import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
################################# #################################

View File

@ -3,7 +3,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn.utils import spectral_norm, weight_norm from torch.nn.utils import spectral_norm, weight_norm
from TTS.utils.audio import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1

View File

@ -3,7 +3,7 @@ import unittest
from tests import get_tests_input_path, get_tests_output_path, get_tests_path from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor from TTS.utils.audio.processor import AudioProcessor
TESTS_PATH = get_tests_path() TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests") OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")

View File

@ -0,0 +1,105 @@
import math
import os
import unittest
from dataclasses import dataclass
import librosa
import numpy as np
from coqpit import Coqpit
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.utils.audio import numpy_transforms as np_transforms
TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
os.makedirs(OUT_PATH, exist_ok=True)
# pylint: disable=no-self-use
class TestNumpyTransforms(unittest.TestCase):
def setUp(self) -> None:
@dataclass
class AudioConfig(Coqpit):
sample_rate: int = 22050
fft_size: int = 1024
num_mels: int = 256
mel_fmax: int = 1800
mel_fmin: int = 0
hop_length: int = 256
win_length: int = 1024
pitch_fmax: int = 450
trim_db: int = -1
min_silence_sec: float = 0.01
gain: float = 1.0
base: float = 10.0
self.config = AudioConfig()
self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)
def test_build_mel_basis(self):
"""Check if the mel basis is correctly built"""
print(" > Testing mel basis building.")
mel_basis = np_transforms.build_mel_basis(**self.config)
self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))
def test_millisec_to_length(self):
"""Check if the conversion from milliseconds to length is correct"""
print(" > Testing millisec to length conversion.")
win_len, hop_len = np_transforms.millisec_to_length(
frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
)
self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
self.assertEqual(win_len, self.config.sample_rate)
def test_amplitude_db_conversion(self):
di = np.random.rand(11)
o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
np.testing.assert_almost_equal(di, o2, decimal=5)
def test_preemphasis_deemphasis(self):
di = np.random.rand(11)
o1 = np_transforms.preemphasis(x=di, coeff=0.95)
o2 = np_transforms.deemphasis(x=o1, coeff=0.95)
np.testing.assert_almost_equal(di, o2, decimal=5)
def test_spec_to_mel(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
spec = np.random.rand(self.config.fft_size // 2 + 1, 20) # [C, T]
mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
self.assertEqual(mel.shape, (self.config.num_mels, 20))
def mel_to_spec(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np.random.rand(self.config.num_mels, 20) # [C, T]
spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))
def test_wav_to_spec(self):
spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
self.assertEqual(
spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
)
def test_wav_to_mel(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
self.assertEqual(
mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
)
def test_compute_f0(self):
pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
assert pitch.shape[0] == mel.shape[1]
def test_load_wav(self):
wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))