Introduce numpy and torch transforms (#1705)

* Refactor audio processing functions

* Add tests for numpy transforms

* Fix imports

* Fix imports2
This commit is contained in:
Eren Gölge 2022-08-08 11:57:50 +02:00 committed by GitHub
parent 7fd9b89ebf
commit d46fbc240c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 702 additions and 173 deletions

View File

@ -1,7 +1,7 @@
import torch
from torch import nn
# from TTS.utils.audio import TorchSTFT
# from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.encoder.models.base_encoder import BaseEncoder

View File

@ -8,7 +8,7 @@ from torch.nn import functional
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio import TorchSTFT
from TTS.utils.audio.torch_transforms import TorchSTFT
# pylint: disable=abstract-method

View File

@ -0,0 +1 @@
from TTS.utils.audio.processor import AudioProcessor

View File

@ -0,0 +1,425 @@
from typing import Tuple
import librosa
import numpy as np
import pyworld as pw
import scipy
import soundfile as sf
# For using kwargs
# pylint: disable=unused-argument
def build_mel_basis(
*,
sample_rate: int = None,
fft_size: int = None,
num_mels: int = None,
mel_fmax: int = None,
mel_fmin: int = None,
**kwargs,
) -> np.ndarray:
"""Build melspectrogram basis.
Returns:
np.ndarray: melspectrogram basis.
"""
if mel_fmax is not None:
assert mel_fmax <= sample_rate // 2
assert mel_fmax - mel_fmin > 0
return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
def millisec_to_length(
*, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
"""Compute hop and window length from milliseconds.
Returns:
Tuple[int, int]: hop length and window length for STFT.
"""
factor = frame_length_ms / frame_shift_ms
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
win_length = int(frame_length_ms / 1000.0 * sample_rate)
hop_length = int(win_length / float(factor))
return win_length, hop_length
def _log(x, base):
if base == 10:
return np.log10(x)
return np.log(x)
def _exp(x, base):
if base == 10:
return np.power(10, x)
return np.exp(x)
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert amplitude values to decibels.
Args:
x (np.ndarray): Amplitude spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Decibels spectrogram.
"""
assert (x < 0).sum() == 0, " [!] Input values must be non-negative."
return gain * _log(np.maximum(1e-8, x), base)
# pylint: disable=no-self-use
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert decibels spectrogram to amplitude spectrogram.
Args:
x (np.ndarray): Decibels spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Amplitude spectrogram.
"""
return _exp(x / gain, base)
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
Args:
x (np.ndarray): Audio signal.
Raises:
RuntimeError: Preemphasis coeff is set to 0.
Returns:
np.ndarray: Decorrelated audio signal.
"""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1, -coef], [1], x)
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Reverse pre-emphasis."""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1], [1, -coef], x)
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
Args:
spec (np.ndarray): Normalized full scale linear spectrogram.
Shapes:
- spec: :math:`[C, T]`
Returns:
np.ndarray: Normalized melspectrogram.
"""
return np.dot(mel_basis, spec)
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to full scale spectrogram."""
assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
inv_mel_basis = np.linalg.pinv(mel_basis)
return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
"""Compute a spectrogram from a waveform.
Args:
wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
"""
D = stft(y=wav, **kwargs)
S = np.abs(D)
return S.astype(np.float32)
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
D = stft(y=wav, **kwargs)
S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
return S.astype(np.float32)
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = spec.copy()
return griffin_lim(spec=S**power, **kwargs)
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
S = mel.copy()
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
return griffin_lim(spec=S**power, **kwargs)
### STFT and ISTFT ###
def stft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
pad_mode: str = "reflect",
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa STFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=fft_size,
hop_length=hop_length,
win_length=win_length,
pad_mode=pad_mode,
window=window,
center=center,
)
def istft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa iSTFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
S_complex = np.abs(spec).astype(np.complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
y = istft(y=S_complex * angles, **kwargs)
return y
def compute_stft_paddings(
*, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""
pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
if not pad_two_sides:
return 0, pad
return pad // 2, pad // 2 + pad % 2
def compute_f0(
*, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
>>> pitch = ap.compute_f0(wav)
"""
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
f0, t = pw.dio(
x.astype(np.double),
fs=sample_rate,
f0_ceil=pitch_fmax,
frame_period=1000 * hop_length / sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
return f0
### Audio Processing ###
def find_endpoint(
*,
wav: np.ndarray = None,
trim_db: float = -40,
sample_rate: int = None,
min_silence_sec=0.8,
gain: float = None,
base: int = None,
**kwargs,
) -> int:
"""Find the last point without silence at the end of a audio signal.
Args:
wav (np.ndarray): Audio signal.
threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
Returns:
int: Last point without silence.
"""
window_length = int(sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x : x + window_length]) < threshold:
return x + hop_length
return len(wav)
def trim_silence(
*,
wav: np.ndarray = None,
sample_rate: int = None,
trim_db: float = None,
win_length: int = None,
hop_length: int = None,
**kwargs,
) -> np.ndarray:
"""Trim silent parts with a threshold and 0.01 sec margin"""
margin = int(sample_rate * 0.01)
wav = wav[margin:-margin]
return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
"""Normalize the volume of an audio signal.
Args:
x (np.ndarray): Raw waveform.
coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.
Returns:
np.ndarray: Volume normalized waveform.
"""
return x / abs(x).max() * coef
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
"""Normalize the volume based on RMS of the signal.
Args:
x (np.ndarray): Raw waveform.
db_level (float): Target dB level in RMS. Defaults to -27.0.
Returns:
np.ndarray: RMS normalized waveform.
"""
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
wav = rms_norm(wav=x, db_level=db_level)
return wav
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
Returns:
np.ndarray: Loaded waveform.
"""
if resample:
# loading with resampling. It is significantly slower.
x, _ = librosa.load(filename, sr=sample_rate)
else:
# SF is faster than librosa for loading files
x, _ = sf.read(filename)
return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
"""Save float waveform to a file using Scipy.
Args:
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
path (str): Path to a output file.
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
"""
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
mu = 2**mulaw_qc - 1
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
"""Recovers waveform from quantized values."""
mu = 2**mulaw_qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
quantize_bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**quantize_bits - 1) / 2
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**quantize_bits - 1) - 1

View File

@ -6,179 +6,14 @@ import pyworld as pw
import scipy.io.wavfile
import scipy.signal
import soundfile as sf
import torch
from torch import nn
from TTS.tts.utils.helpers import StandardScaler
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True pad the audio with (n_fft - hop_length) / 2). Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute the melspectrograms otherwise. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
"""Audio Processor for TTS used by all the data pipelines.
TODO: Make this a dataclass to replace `BaseAudioConfig`.
class AudioProcessor(object):
"""Audio Processor for TTS.
Note:
All the class arguments are set to default values to enable a flexible initialization

View File

@ -0,0 +1,163 @@
import librosa
import torch
from torch import nn
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True pad the audio with (n_fft - hop_length) / 2). Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute the melspectrograms otherwise. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain

View File

@ -4,7 +4,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from TTS.utils.audio import TorchSTFT
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
#################################

View File

@ -3,7 +3,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from TTS.utils.audio import TorchSTFT
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
LRELU_SLOPE = 0.1

View File

@ -3,7 +3,7 @@ import unittest
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.processor import AudioProcessor
TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")

View File

@ -0,0 +1,105 @@
import math
import os
import unittest
from dataclasses import dataclass
import librosa
import numpy as np
from coqpit import Coqpit
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.utils.audio import numpy_transforms as np_transforms
TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
os.makedirs(OUT_PATH, exist_ok=True)
# pylint: disable=no-self-use
class TestNumpyTransforms(unittest.TestCase):
def setUp(self) -> None:
@dataclass
class AudioConfig(Coqpit):
sample_rate: int = 22050
fft_size: int = 1024
num_mels: int = 256
mel_fmax: int = 1800
mel_fmin: int = 0
hop_length: int = 256
win_length: int = 1024
pitch_fmax: int = 450
trim_db: int = -1
min_silence_sec: float = 0.01
gain: float = 1.0
base: float = 10.0
self.config = AudioConfig()
self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)
def test_build_mel_basis(self):
"""Check if the mel basis is correctly built"""
print(" > Testing mel basis building.")
mel_basis = np_transforms.build_mel_basis(**self.config)
self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))
def test_millisec_to_length(self):
"""Check if the conversion from milliseconds to length is correct"""
print(" > Testing millisec to length conversion.")
win_len, hop_len = np_transforms.millisec_to_length(
frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
)
self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
self.assertEqual(win_len, self.config.sample_rate)
def test_amplitude_db_conversion(self):
di = np.random.rand(11)
o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
np.testing.assert_almost_equal(di, o2, decimal=5)
def test_preemphasis_deemphasis(self):
di = np.random.rand(11)
o1 = np_transforms.preemphasis(x=di, coeff=0.95)
o2 = np_transforms.deemphasis(x=o1, coeff=0.95)
np.testing.assert_almost_equal(di, o2, decimal=5)
def test_spec_to_mel(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
spec = np.random.rand(self.config.fft_size // 2 + 1, 20) # [C, T]
mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
self.assertEqual(mel.shape, (self.config.num_mels, 20))
def mel_to_spec(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np.random.rand(self.config.num_mels, 20) # [C, T]
spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))
def test_wav_to_spec(self):
spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
self.assertEqual(
spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
)
def test_wav_to_mel(self):
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
self.assertEqual(
mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
)
def test_compute_f0(self):
pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
mel_basis = np_transforms.build_mel_basis(**self.config)
mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
assert pitch.shape[0] == mel.shape[1]
def test_load_wav(self):
wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))