coqui-tts/TTS/utils/audio/torch_transforms.py

266 lines
7.9 KiB
Python

import logging
import librosa
import torch
from torch import nn
# Module-level logger; wav_to_spec reports out-of-range sample values through it.
logger = logging.getLogger(__name__)
# Cache of window tensors filled by wav_to_spec, keyed by a "<win_length>_<dtype>_<device>" string.
hann_window = {}
# Cache of mel filter-bank tensors filled by spec_to_mel, keyed by a string built from
# its arguments plus the dtype and device of the input spectrogram.
mel_basis = {}
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
    """Compress the dynamic range of a spectrogram with a natural logarithm.

    Args:
        x: Amplitude values.
        spec_gain: Factor multiplied into the floored values before the log.
        clip_val: Lower bound applied to ``x`` before scaling.

    Returns:
        ``log(max(x, clip_val) * spec_gain)``, computed element-wise.
    """
    floored = x.clamp(min=clip_val)
    return floored.mul(spec_gain).log()
def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
    """Expand a log-compressed spectrogram back to amplitudes.

    Args:
        x: Log-domain values.
        spec_gain: Divisor applied after exponentiation.

    Returns:
        ``exp(x) / spec_gain``, computed element-wise.
    """
    return x.exp().div(spec_gain)
def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
    """Compute a linear magnitude spectrogram from a batch of waveforms.

    The waveform is reflect-padded by ``(n_fft - hop_length) / 2`` samples on each
    side, transformed with ``torch.stft`` using a cached Hann window, and reduced
    to ``sqrt(real^2 + imag^2 + 1e-6)``.

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    # Samples outside [-1, 1] are only reported, never clipped.
    lowest = torch.min(y)
    if lowest < -1.0:
        logger.info("min value is %.3f", lowest)
    highest = torch.max(y)
    if highest > 1.0:
        logger.info("max value is %.3f", highest)

    # One window tensor is cached per (win_length, dtype, device) combination.
    cache_key = f"{win_length}_{y.dtype}_{y.device}"
    window = hann_window.get(cache_key)
    if window is None:
        window = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
        hann_window[cache_key] = window

    # Padding is applied with an explicit axis inserted at dim 1, which is dropped again afterwards.
    half_pad = int((n_fft - hop_length) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (half_pad, half_pad), mode="reflect").squeeze(1)

    complex_frames = torch.stft(
        y,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    real_imag = torch.view_as_real(complex_frames)
    # The small constant keeps the argument of the square root strictly positive.
    return torch.sqrt(real_imag.pow(2).sum(-1) + 1e-6)
def spec_to_mel(
    spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
    """Project a linear spectrogram onto a mel filter bank and log-compress it.

    The filter bank is created with ``librosa.filters.mel`` on first use and cached
    in the module-level ``mel_basis`` dict for later calls with the same settings.

    Args:
        spec: Linear spectrogram.
        n_fft: FFT size the spectrogram was computed with.
        num_mels: Number of mel bands.
        sample_rate: Sampling rate of the audio.
        fmin: Lowest frequency of the filter bank.
        fmax: Highest frequency of the filter bank.

    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    # Every argument that shapes the filter bank is part of the key. Leaving out
    # num_mels, sample_rate or fmin would let a later call with different values
    # silently reuse a stale matrix.
    cache_key = f"{n_fft}_{num_mels}_{sample_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}"
    filterbank = mel_basis.get(cache_key)
    if filterbank is None:
        # TODO: switch librosa to torchaudio
        weights = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        filterbank = torch.from_numpy(weights).to(dtype=spec.dtype, device=spec.device)
        mel_basis[cache_key] = filterbank
    return amp_to_db(torch.matmul(filterbank, spec))
def wav_to_mel(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sample_rate: int,
    hop_length: int,
    win_length: int,
    fmin: float,
    fmax: float,
    *,
    center: bool = False,
) -> torch.Tensor:
    """Convert a batch of waveforms to mel spectrograms in one call.

    Chains :func:`wav_to_spec` and :func:`spec_to_mel`.

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    linear_spec = wav_to_spec(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=center)
    return spec_to_mel(
        linear_spec,
        n_fft=n_fft,
        num_mels=num_mels,
        sample_rate=sample_rate,
        fmin=fmin,
        fmax=fmax,
    )
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing functions using Torch for faster batch processing.
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of audio samples between adjacent STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True, reflect-pad the audio with (n_fft - hop_length) / 2 samples on each side. Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to 0.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to 80.
use_mel (bool, optional):
If True, compute melspectrograms; otherwise return linear spectrograms. Defaults to False.
do_amp_to_db (bool, optional):
enable/disable amplitude to dB conversion of the output spectrogram. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
normalized=False,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
self.normalized = normalized
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.view_as_real(
torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=self.normalized,
onesided=True,
return_complex=True,
)
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
    """Create the mel filter bank with librosa and store it on ``self.mel_basis``.

    The filter bank is built from ``sample_rate``, ``n_fft``, ``n_mels``,
    ``mel_fmin``, ``mel_fmax``, ``use_htk`` and ``mel_norm`` and kept as a
    float32 tensor.
    """
    filter_kwargs = {
        "sr": self.sample_rate,
        "n_fft": self.n_fft,
        "n_mels": self.n_mels,
        "fmin": self.mel_fmin,
        "fmax": self.mel_fmax,
        "htk": self.use_htk,
        "norm": self.mel_norm,
    }
    filterbank = librosa.filters.mel(**filter_kwargs)
    self.mel_basis = torch.from_numpy(filterbank).float()