diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py index 84e9967f..e75ab6c4 100644 --- a/TTS/encoder/models/resnet.py +++ b/TTS/encoder/models/resnet.py @@ -1,7 +1,7 @@ import torch from torch import nn -# from TTS.utils.audio import TorchSTFT +# from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.encoder.models.base_encoder import BaseEncoder diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 5130ac0b..b9a03af1 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -8,7 +8,7 @@ from torch.nn import functional from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT # pylint: disable=abstract-method diff --git a/TTS/utils/audio/__init__.py b/TTS/utils/audio/__init__.py new file mode 100644 index 00000000..f18f2219 --- /dev/null +++ b/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 00000000..f6f03855 --- /dev/null +++ b/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,425 @@ +from typing import Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy +import soundfile as sf + +# For using kwargs +# pylint: disable=unused-argument + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute hop and window length from milliseconds. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + win_length = int(frame_length_ms / 1000.0 * sample_rate) + hop_length = int(win_length / float(factor)) + return win_length, hop_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + assert (x < 0).sum() == 0, " [!] Input values must be non-negative." + return gain * _log(np.maximum(1e-8, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. 
+    """
+    return _exp(x / gain, base)
+
+
+def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
+    """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
+
+    Args:
+        x (np.ndarray): Audio signal.
+
+    Raises:
+        RuntimeError: Preemphasis coefficient is set to 0.
+
+    Returns:
+        np.ndarray: Decorrelated audio signal.
+    """
+    if coef == 0:
+        raise RuntimeError(" [!] Preemphasis coefficient is set to 0.0.")
+    return scipy.signal.lfilter([1, -coef], [1], x)
+
+
+def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
+    """Reverse pre-emphasis."""
+    if coef == 0:
+        raise RuntimeError(" [!] Preemphasis coefficient is set to 0.0.")
+    return scipy.signal.lfilter([1], [1, -coef], x)
+
+
+def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Convert a full scale linear spectrogram output of a network to a melspectrogram.
+
+    Args:
+        spec (np.ndarray): Normalized full scale linear spectrogram.
+
+    Shapes:
+        - spec: :math:`[C, T]`
+
+    Returns:
+        np.ndarray: Normalized melspectrogram.
+    """
+    return np.dot(mel_basis, spec)
+
+
+def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Convert a melspectrogram to a full scale spectrogram."""
+    assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
+    inv_mel_basis = np.linalg.pinv(mel_basis)
+    return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
+
+
+def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Compute a spectrogram from a waveform.
+
+    Args:
+        wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
+
+    Returns:
+        np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
+    """
+    D = stft(y=wav, **kwargs)
+    S = np.abs(D)
+    return S.astype(np.float32)
+
+
+def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
+    """Compute a melspectrogram from a waveform."""
+    D = stft(y=wav, **kwargs)
+    S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
+    return S.astype(np.float32)
+
+
+def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
+    """Convert a spectrogram to a waveform using the Griffin-Lim vocoder."""
+    S = spec.copy()
+    return griffin_lim(spec=S**power, **kwargs)
+
+
+def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
+    """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder."""
+    S = mel.copy()
+    S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"])  # Convert back to linear
+    return griffin_lim(spec=S**power, **kwargs)
+
+
+### STFT and ISTFT ###
+def stft(
+    *,
+    y: np.ndarray = None,
+    fft_size: int = None,
+    hop_length: int = None,
+    win_length: int = None,
+    pad_mode: str = "reflect",
+    window: str = "hann",
+    center: bool = True,
+    **kwargs,
+) -> np.ndarray:
+    """Librosa STFT wrapper.
+
+    Check http://librosa.org/doc/main/generated/librosa.stft.html for the argument details.
+
+    Returns:
+        np.ndarray: Complex number array.
+    """
+    return librosa.stft(
+        y=y,
+        n_fft=fft_size,
+        hop_length=hop_length,
+        win_length=win_length,
+        pad_mode=pad_mode,
+        window=window,
+        center=center,
+    )
+
+
+def istft(
+    *,
+    y: np.ndarray = None,
+    fft_size: int = None,
+    hop_length: int = None,
+    win_length: int = None,
+    window: str = "hann",
+    center: bool = True,
+    **kwargs,
+) -> np.ndarray:
+    """Librosa iSTFT wrapper.
+
+    Check http://librosa.org/doc/main/generated/librosa.istft.html for the argument details.
+
+    Returns:
+        np.ndarray: Time-domain waveform.
+    """
+    return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
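+
+
+# Minimal usage sketch (illustrative, not part of the original patch): `stft`
+# and `istft` are expected to round-trip a waveform up to window/padding
+# effects at the edges. The parameter values below are example assumptions,
+# not module defaults.
+#
+# >>> import numpy as np
+# >>> y = np.random.rand(22050).astype(np.float32)
+# >>> D = stft(y=y, fft_size=1024, hop_length=256, win_length=1024)
+# >>> y_hat = istft(y=D, hop_length=256, win_length=1024)  # approximately reconstructs y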
+def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
+    """Reconstruct a waveform from a magnitude spectrogram by Griffin-Lim phase estimation."""
+    angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
+    S_complex = np.abs(spec).astype(complex)
+    y = istft(y=S_complex * angles, **kwargs)
+    if not np.isfinite(y).all():
+        print(" [!] Waveform is not finite everywhere. Skipping the GL.")
+        return np.array([0.0])
+    for _ in range(num_iter):
+        angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
+        y = istft(y=S_complex * angles, **kwargs)
+    return y
+
+
+def compute_stft_paddings(
+    *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
+) -> Tuple[int, int]:
+    """Compute the paddings used by Librosa's STFT: either right padding only (final frame)
+    or both sides padding (first and final frames)."""
+    pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
+    if not pad_two_sides:
+        return 0, pad
+    return pad // 2, pad // 2 + pad % 2
+
+
+def compute_f0(
+    *, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
+) -> np.ndarray:
+    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
+
+    Args:
+        x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
+
+    Returns:
+        np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
+
+    Examples:
+        >>> WAV_FILE = filename = librosa.util.example_audio_file()
+        >>> from TTS.config import BaseAudioConfig
+        >>> from TTS.utils.audio.processor import AudioProcessor
+        >>> conf = BaseAudioConfig(pitch_fmax=8000)
+        >>> ap = AudioProcessor(**conf)
+        >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
+        >>> pitch = ap.compute_f0(wav)
+    """
+    assert pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
+
+    f0, t = pw.dio(
+        x.astype(np.double),
+        fs=sample_rate,
+        f0_ceil=pitch_fmax,
+        frame_period=1000 * hop_length / sample_rate,
+    )
+    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
+    return f0
+
+
+### Audio Processing ###
+def find_endpoint(
+    *,
+    wav: np.ndarray = None,
+    trim_db: float = -40,
+    sample_rate: int = None,
+    min_silence_sec=0.8,
+    gain: float = None,
+    base: int = None,
+    **kwargs,
+) -> int:
+    """Find the last point without silence at the end of an audio signal.
+
+    Args:
+        wav (np.ndarray): Audio signal.
+        trim_db (float, optional): Silence threshold in decibels. Defaults to -40.
+        min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.
+        gain (float, optional): Gain used to convert trim_db to an amplitude threshold. Defaults to None.
+        base (int, optional): Base of the logarithm used to convert trim_db to an amplitude threshold. Defaults to None.
+
+    Returns:
+        int: Last point without silence.
+    """
+    window_length = int(sample_rate * min_silence_sec)
+    hop_length = int(window_length / 4)
+    threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
+    for x in range(hop_length, len(wav) - window_length, hop_length):
+        if np.max(wav[x : x + window_length]) < threshold:
+            return x + hop_length
+    return len(wav)
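+
+
+# Minimal usage sketch (illustrative; the file name and parameter values are
+# example assumptions matching the unit tests, not defaults): locate the end
+# of speech and cut the trailing silence with `find_endpoint` above.
+#
+# >>> wav = load_wav(filename="example_1.wav", sample_rate=22050)
+# >>> end = find_endpoint(wav=wav, trim_db=-1, sample_rate=22050, gain=1.0, base=10)
+# >>> wav = wav[:end]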
+def trim_silence(
+    *,
+    wav: np.ndarray = None,
+    sample_rate: int = None,
+    trim_db: float = None,
+    win_length: int = None,
+    hop_length: int = None,
+    **kwargs,
+) -> np.ndarray:
+    """Trim silent parts with a threshold and a 0.01 sec margin."""
+    margin = int(sample_rate * 0.01)
+    wav = wav[margin:-margin]
+    return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
+
+
+def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
+    """Normalize the volume of an audio signal.
+
+    Args:
+        x (np.ndarray): Raw waveform.
+        coef (float): Coefficient used to rescale the maximum value. Defaults to 0.95.
+
+    Returns:
+        np.ndarray: Volume normalized waveform.
+    """
+    return x / abs(x).max() * coef
+
+
+def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
+    r = 10 ** (db_level / 20)
+    a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
+    return wav * a
+
+
+def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
+    """Normalize the volume based on the RMS of the signal.
+
+    Args:
+        x (np.ndarray): Raw waveform.
+        db_level (float): Target dB level in RMS. Defaults to -27.0.
+
+    Returns:
+        np.ndarray: RMS normalized waveform.
+    """
+    assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
+    wav = rms_norm(wav=x, db_level=db_level)
+    return wav
+
+
+def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
+    """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
+
+    Resampling slows down loading the file significantly. Therefore it is recommended to resample the file beforehand.
+
+    Args:
+        filename (str): Path to the wav file.
+        sample_rate (int, optional): Sampling rate for resampling. Defaults to None.
+        resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
+
+    Returns:
+        np.ndarray: Loaded waveform.
+    """
+    if resample:
+        # loading with resampling. It is significantly slower.
+        x, _ = librosa.load(filename, sr=sample_rate)
+    else:
+        # SF is faster than librosa for loading files
+        x, _ = sf.read(filename)
+    return x
+
+
+def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
+    """Save a float waveform to a file using Scipy.
+
+    Args:
+        wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
+        path (str): Path to an output file.
+        sample_rate (int, optional): Sampling rate used for saving to the file. Defaults to None.
+ """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16)) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1 diff --git a/TTS/utils/audio.py b/TTS/utils/audio/processor.py similarity index 84% rename from TTS/utils/audio.py rename to TTS/utils/audio/processor.py index fc9d1942..5a63b444 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio/processor.py @@ -6,179 +6,14 @@ import pyworld as pw import scipy.io.wavfile import scipy.signal import soundfile as sf -import torch -from torch import nn from TTS.tts.utils.helpers import StandardScaler - -class TorchSTFT(nn.Module): # pylint: disable=abstract-method - """Some of the audio processing funtions using Torch for faster batch processing. - - TODO: Merge this with audio.py - - Args: - - n_fft (int): - FFT window size for STFT. - - hop_length (int): - number of frames between STFT columns. - - win_length (int, optional): - STFT window length. - - pad_wav (bool, optional): - If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. - - window (str, optional): - The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" - - sample_rate (int, optional): - target audio sampling rate. Defaults to None. - - mel_fmin (int, optional): - minimum filter frequency for computing melspectrograms. Defaults to None. - - mel_fmax (int, optional): - maximum filter frequency for computing melspectrograms. Defaults to None. - - n_mels (int, optional): - number of melspectrogram dimensions. Defaults to None. - - use_mel (bool, optional): - If True compute the melspectrograms otherwise. Defaults to False. - - do_amp_to_db_linear (bool, optional): - enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. - - spec_gain (float, optional): - gain applied when converting amplitude to DB. Defaults to 1.0. - - power (float, optional): - Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. - - use_htk (bool, optional): - Use HTK formula in mel filter instead of Slaney. - - mel_norm (None, 'slaney', or number, optional): - If 'slaney', divide the triangular mel weights by the width of the mel band - (area normalization). - - If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. 
-            See `librosa.util.normalize` for a full description of supported norm values
-            (including `+-np.inf`).
-
-            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
-    """
-
-    def __init__(
-        self,
-        n_fft,
-        hop_length,
-        win_length,
-        pad_wav=False,
-        window="hann_window",
-        sample_rate=None,
-        mel_fmin=0,
-        mel_fmax=None,
-        n_mels=80,
-        use_mel=False,
-        do_amp_to_db=False,
-        spec_gain=1.0,
-        power=None,
-        use_htk=False,
-        mel_norm="slaney",
-    ):
-        super().__init__()
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.pad_wav = pad_wav
-        self.sample_rate = sample_rate
-        self.mel_fmin = mel_fmin
-        self.mel_fmax = mel_fmax
-        self.n_mels = n_mels
-        self.use_mel = use_mel
-        self.do_amp_to_db = do_amp_to_db
-        self.spec_gain = spec_gain
-        self.power = power
-        self.use_htk = use_htk
-        self.mel_norm = mel_norm
-        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
-        self.mel_basis = None
-        if use_mel:
-            self._build_mel_basis()
-
-    def __call__(self, x):
-        """Compute spectrogram frames by torch based stft.
-
-        Args:
-            x (Tensor): input waveform
-
-        Returns:
-            Tensor: spectrogram frames.
-
-        Shapes:
-            x: [B x T] or [:math:`[B, 1, T]`]
-        """
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        if self.pad_wav:
-            padding = int((self.n_fft - self.hop_length) / 2)
-            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
-        # B x D x T x 2
-        o = torch.stft(
-            x.squeeze(1),
-            self.n_fft,
-            self.hop_length,
-            self.win_length,
-            self.window,
-            center=True,
-            pad_mode="reflect",  # compatible with audio.py
-            normalized=False,
-            onesided=True,
-            return_complex=False,
-        )
-        M = o[:, :, :, 0]
-        P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
-
-        if self.power is not None:
-            S = S**self.power
-
-        if self.use_mel:
-            S = torch.matmul(self.mel_basis.to(x), S)
-        if self.do_amp_to_db:
-            S = self._amp_to_db(S, spec_gain=self.spec_gain)
-        return S
-
-    def _build_mel_basis(self):
-        mel_basis = librosa.filters.mel(
-            self.sample_rate,
-            self.n_fft,
-            n_mels=self.n_mels,
-            fmin=self.mel_fmin,
-            fmax=self.mel_fmax,
-            htk=self.use_htk,
-            norm=self.mel_norm,
-        )
-        self.mel_basis = torch.from_numpy(mel_basis).float()
-
-    @staticmethod
-    def _amp_to_db(x, spec_gain=1.0):
-        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
-
-    @staticmethod
-    def _db_to_amp(x, spec_gain=1.0):
-        return torch.exp(x) / spec_gain
-
-
 # pylint: disable=too-many-public-methods
-class AudioProcessor(object):
-    """Audio Processor for TTS used by all the data pipelines.
 
-    TODO: Make this a dataclass to replace `BaseAudioConfig`.
+
+class AudioProcessor(object):
+    """Audio Processor for TTS.
 
     Note:
         All the class arguments are set to default values to enable a flexible initialization
diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py
new file mode 100644
index 00000000..d4523ad0
--- /dev/null
+++ b/TTS/utils/audio/torch_transforms.py
@@ -0,0 +1,163 @@
+import librosa
+import torch
+from torch import nn
+
+
+class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
+    """Some of the audio processing functions using Torch for faster batch processing.
+
+    Args:
+
+        n_fft (int):
+            FFT window size for STFT.
+
+        hop_length (int):
+            number of frames between STFT columns.
+
+        win_length (int, optional):
+            STFT window length.
+
+        pad_wav (bool, optional):
+            If True pad the audio with (n_fft - hop_length) / 2. Defaults to False.
+
+        window (str, optional):
+            The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window".
+
+        sample_rate (int, optional):
+            target audio sampling rate. Defaults to None.
+
+        mel_fmin (int, optional):
+            minimum filter frequency for computing melspectrograms. Defaults to None.
+
+        mel_fmax (int, optional):
+            maximum filter frequency for computing melspectrograms. Defaults to None.
+
+        n_mels (int, optional):
+            number of melspectrogram dimensions. Defaults to None.
+
+        use_mel (bool, optional):
+            If True compute the melspectrograms. Defaults to False.
+
+        do_amp_to_db (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
+
+        spec_gain (float, optional):
+            gain applied when converting amplitude to dB. Defaults to 1.0.
+
+        power (float, optional):
+            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
+
+        use_htk (bool, optional):
+            Use the HTK formula in the mel filter instead of Slaney.
+
+        mel_norm (None, 'slaney', or number, optional):
+            If 'slaney', divide the triangular mel weights by the width of the mel band
+            (area normalization).
+
+            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
+            See `librosa.util.normalize` for a full description of supported norm values
+            (including `+-np.inf`).
+
+            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
+    """
+
+    def __init__(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        pad_wav=False,
+        window="hann_window",
+        sample_rate=None,
+        mel_fmin=0,
+        mel_fmax=None,
+        n_mels=80,
+        use_mel=False,
+        do_amp_to_db=False,
+        spec_gain=1.0,
+        power=None,
+        use_htk=False,
+        mel_norm="slaney",
+    ):
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.pad_wav = pad_wav
+        self.sample_rate = sample_rate
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.n_mels = n_mels
+        self.use_mel = use_mel
+        self.do_amp_to_db = do_amp_to_db
+        self.spec_gain = spec_gain
+        self.power = power
+        self.use_htk = use_htk
+        self.mel_norm = mel_norm
+        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
+        self.mel_basis = None
+        if use_mel:
+            self._build_mel_basis()
+
+    def __call__(self, x):
+        """Compute spectrogram frames by torch based stft.
+
+        Args:
+            x (Tensor): input waveform
+
+        Returns:
+            Tensor: spectrogram frames.
+
+        Shapes:
+            x: :math:`[B, T]` or :math:`[B, 1, T]`
+        """
+        if x.ndim == 2:
+            x = x.unsqueeze(1)
+        if self.pad_wav:
+            padding = int((self.n_fft - self.hop_length) / 2)
+            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
+        # B x D x T x 2
+        o = torch.stft(
+            x.squeeze(1),
+            self.n_fft,
+            self.hop_length,
+            self.win_length,
+            self.window,
+            center=True,
+            pad_mode="reflect",  # compatible with audio.py
+            normalized=False,
+            onesided=True,
+            return_complex=False,
+        )
+        M = o[:, :, :, 0]
+        P = o[:, :, :, 1]
+        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
+
+        if self.power is not None:
+            S = S**self.power
+
+        if self.use_mel:
+            S = torch.matmul(self.mel_basis.to(x), S)
+        if self.do_amp_to_db:
+            S = self._amp_to_db(S, spec_gain=self.spec_gain)
+        return S
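+
+    # Minimal usage sketch (illustrative; values are example assumptions,
+    # not defaults):
+    #
+    # >>> torch_stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024,
+    # ...                        sample_rate=22050, n_mels=80, use_mel=True)
+    # >>> wav = torch.rand(4, 22050)  # [B, T]
+    # >>> mel = torch_stft(wav)       # [B, n_mels, T // hop_length + 1]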
+
+    def _build_mel_basis(self):
+        mel_basis = librosa.filters.mel(
+            sr=self.sample_rate,
+            n_fft=self.n_fft,
+            n_mels=self.n_mels,
+            fmin=self.mel_fmin,
+            fmax=self.mel_fmax,
+            htk=self.use_htk,
+            norm=self.mel_norm,
+        )
+        self.mel_basis = torch.from_numpy(mel_basis).float()
+
+    @staticmethod
+    def _amp_to_db(x, spec_gain=1.0):
+        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
+
+    @staticmethod
+    def _db_to_amp(x, spec_gain=1.0):
+        return torch.exp(x) / spec_gain
diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py
index 848e292b..befc43cc 100644
--- a/TTS/vocoder/layers/losses.py
+++ b/TTS/vocoder/layers/losses.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from TTS.utils.audio import TorchSTFT
+from TTS.utils.audio.torch_transforms import TorchSTFT
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
 
 #################################
diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py
index d6b0e5d5..34e2d1c2 100644
--- a/TTS/vocoder/models/univnet_discriminator.py
+++ b/TTS/vocoder/models/univnet_discriminator.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn.utils import spectral_norm, weight_norm
 
-from TTS.utils.audio import TorchSTFT
+from TTS.utils.audio.torch_transforms import TorchSTFT
 from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
 
 LRELU_SLOPE = 0.1
diff --git a/tests/aux_tests/test_audio_processor.py b/tests/aux_tests/test_audio_processor.py
index 56611692..d01aeffa 100644
--- a/tests/aux_tests/test_audio_processor.py
+++ b/tests/aux_tests/test_audio_processor.py
@@ -3,7 +3,7 @@ import unittest
 
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.config import BaseAudioConfig
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.processor import AudioProcessor
 
 TESTS_PATH = get_tests_path()
 OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
diff --git a/tests/aux_tests/test_numpy_transforms.py b/tests/aux_tests/test_numpy_transforms.py
new file mode 100644
index 00000000..0c1836b9
--- /dev/null
+++ b/tests/aux_tests/test_numpy_transforms.py
@@ -0,0 +1,105 @@
+import math
+import os
+import unittest
+from dataclasses import dataclass
+
+import librosa
+import numpy as np
+from coqpit import Coqpit
+
+from tests import get_tests_input_path, get_tests_output_path, get_tests_path
+from TTS.utils.audio import numpy_transforms as np_transforms
+
+TESTS_PATH = get_tests_path()
+OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+
+os.makedirs(OUT_PATH, exist_ok=True)
+
+
+# pylint: disable=no-self-use
+
+
+class TestNumpyTransforms(unittest.TestCase):
+    def setUp(self) -> None:
+        @dataclass
+        class AudioConfig(Coqpit):
+            sample_rate: int = 22050
+            fft_size: int = 1024
+            num_mels: int = 256
+            mel_fmax: int = 1800
+            mel_fmin: int = 0
+            hop_length: int = 256
+            win_length: int = 1024
+            pitch_fmax: int = 450
+            trim_db: int = -1
+            min_silence_sec: float = 0.01
+            gain: float = 1.0
+            base: float = 10.0
+
+        self.config = AudioConfig()
+        self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)
+
+    def test_build_mel_basis(self):
+        """Check if the mel basis is correctly built"""
+        print(" > Testing mel basis building.")
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))
+
+    def test_millisec_to_length(self):
+        """Check if the conversion from milliseconds to length is correct"""
+        print(" > Testing millisec to length conversion.")
+        win_len, hop_len = np_transforms.millisec_to_length(
+            frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
+        )
+        self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
+        self.assertEqual(win_len, self.config.sample_rate)
+
+    def test_amplitude_db_conversion(self):
+        di = np.random.rand(11)
+        o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
+        o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
+        np.testing.assert_almost_equal(di, o2, decimal=5)
+
+    def test_preemphasis_deemphasis(self):
+        di = np.random.rand(11)
+        o1 = np_transforms.preemphasis(x=di, coef=0.95)
+        o2 = np_transforms.deemphasis(x=o1, coef=0.95)
+        np.testing.assert_almost_equal(di, o2, decimal=5)
+
+    def test_spec_to_mel(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        spec = np.random.rand(self.config.fft_size // 2 + 1, 20)  # [C, T]
+        mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
+        self.assertEqual(mel.shape, (self.config.num_mels, 20))
+
+    def test_mel_to_spec(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np.random.rand(self.config.num_mels, 20)  # [C, T]
+        spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
+        self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))
+
+    def test_wav_to_spec(self):
+        spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
+        self.assertEqual(
+            spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
+        )
+
+    def test_wav_to_mel(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
+        self.assertEqual(
+            mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
+        )
+
+    def test_compute_f0(self):
+        pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
+        assert pitch.shape[0] == mel.shape[1]
+
+    def test_load_wav(self):
+        wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
+        wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
+        self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
+        self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))
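+
+    def test_mulaw_round_trip(self):
+        # Editor's sketch (not in the original patch): mu-law companding should
+        # approximately invert for signals in [-1, 1]. `mulaw_encode` returns
+        # quantized integers in [0, mu], so map back to [-1, 1] before decoding.
+        mu = 2**8 - 1
+        wav = self.sample_wav[:1000]
+        encoded = np_transforms.mulaw_encode(wav=wav, mulaw_qc=8)
+        decoded = np_transforms.mulaw_decode(wav=2 * encoded / mu - 1, mulaw_qc=8)
+        np.testing.assert_allclose(wav, decoded, atol=0.05)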