# Mirror of https://github.com/coqui-ai/TTS.git (266 lines, 7.9 KiB, Python)
import logging
|
|
|
|
import librosa
|
|
import torch
|
|
from torch import nn
|
|
|
|
logger = logging.getLogger(__name__)


# Cache of Hann window tensors used by `wav_to_spec`, keyed by the string
# f"{win_length}_{dtype}_{device}" so each window is built once per configuration.
# Module-level mutable state; nothing in this module ever evicts entries.
hann_window = {}
# Cache of mel filterbank tensors used by `spec_to_mel`, keyed by a string made
# from filterbank parameters plus the spectrogram's dtype and device.
# Module-level mutable state; nothing in this module ever evicts entries.
mel_basis = {}
|
|
|
|
|
|
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
    """Spectral normalization / dynamic range compression.

    Returns the natural logarithm of ``max(x, clip_val) * spec_gain``. Values
    below ``clip_val`` are raised to it first so the logarithm stays finite.
    Despite the name, no decibel scaling is applied.
    """
    floored = x.clamp(min=clip_val)
    scaled = floored * spec_gain
    return scaled.log()
|
|
|
|
|
|
def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
    """Spectral denormalization / dynamic range decompression.

    Computes ``exp(x) / spec_gain``. This undoes ``amp_to_db`` for inputs that
    were not clipped there.
    """
    amplitude = x.exp()
    return amplitude / spec_gain
|
|
|
|
|
|
def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
    """Turn a batch of waveforms into a linear magnitude spectrogram.

    Before the STFT, the signal is reflect-padded by
    ``int((n_fft - hop_length) / 2)`` samples on both sides. The Hann window is
    cached in the module-level ``hann_window`` dict, with one entry per window
    length, dtype and device. Samples outside ``[-1, 1]`` are only reported
    through the logger; they are not clipped.

    Args:
        y: Waveform batch.
        n_fft: FFT size.
        hop_length: Hop between successive frames, in samples.
        win_length: Window length, in samples.
        center: Forwarded to ``torch.stft``.

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]` with ``C = n_fft // 2 + 1`` (one-sided STFT)
    """
    wav = y.squeeze(1)

    lowest = torch.min(wav)
    if lowest < -1.0:
        logger.info("min value is %.3f", lowest)
    highest = torch.max(wav)
    if highest > 1.0:
        logger.info("max value is %.3f", highest)

    window_key = f"{win_length}_{wav.dtype}_{wav.device}"
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_length).to(dtype=wav.dtype, device=wav.device)
    window = hann_window[window_key]

    pad = int((n_fft - hop_length) / 2)
    # F.pad's reflect mode needs a channel dimension, hence the unsqueeze/squeeze pair.
    padded = torch.nn.functional.pad(wav.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)

    complex_spec = torch.stft(
        padded,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Real/imag pairs -> magnitude. The 1e-6 presumably keeps sqrt (and its
    # gradient) away from zero.
    real_imag = torch.view_as_real(complex_spec)
    return torch.sqrt(real_imag.pow(2).sum(-1) + 1e-6)
|
|
|
|
|
|
def spec_to_mel(
    spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
    """Project a linear magnitude spectrogram onto a mel filterbank and log-compress it.

    The filterbank comes from ``librosa.filters.mel``. It is built once per
    distinct ``(n_fft, num_mels, sample_rate, fmin, fmax, dtype, device)`` and
    cached in the module-level ``mel_basis`` dict. The result goes through
    ``amp_to_db`` with its default gain and clip value.

    Args:
        spec: Linear magnitude spectrogram.
        n_fft: FFT size used to compute ``spec``.
        num_mels: Number of mel bands.
        sample_rate: Audio sampling rate in Hz.
        fmin: Lowest filter frequency in Hz.
        fmax: Highest filter frequency in Hz.

    Args Shapes:
        - spec : :math:`[B,C,T]` with ``C = n_fft // 2 + 1``

    Return Shapes:
        - mel : :math:`[B,C,T]` with ``C = num_mels``
    """
    # The key must include every argument that changes the filterbank. Keying on
    # n_fft/fmax alone silently reused a stale basis when num_mels, sample_rate
    # or fmin differed between calls.
    cache_key = f"{n_fft}_{num_mels}_{sample_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}"
    if cache_key not in mel_basis:
        # TODO: switch librosa to torchaudio
        basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[cache_key] = torch.from_numpy(basis).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[cache_key], spec)
    return amp_to_db(mel)
|
|
|
|
|
|
def wav_to_mel(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sample_rate: int,
    hop_length: int,
    win_length: int,
    fmin: float,
    fmax: float,
    *,
    center: bool = False,
) -> torch.Tensor:
    """Compute a log mel spectrogram straight from waveforms.

    Chains ``wav_to_spec`` (linear magnitude STFT) and ``spec_to_mel``
    (mel projection plus ``amp_to_db`` compression).

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - mel : :math:`[B,C,T]` with ``C = num_mels``
    """
    linear = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
    return spec_to_mel(
        spec=linear,
        n_fft=n_fft,
        num_mels=num_mels,
        sample_rate=sample_rate,
        fmin=fmin,
        fmax=fmax,
    )
|
|
|
|
|
|
class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:

        n_fft (int):
            FFT window size for STFT.

        hop_length (int):
            number of audio samples between adjacent STFT columns.

        win_length (int):
            STFT window length.

        pad_wav (bool, optional):
            If True, reflect-pad the audio with int((n_fft - hop_length) / 2) samples on each side
            before the STFT. Defaults to False.

        window (str, optional):
            The name of a `torch` function used to create the window tensor that is applied/multiplied
            to each frame/window. Defaults to "hann_window".

        sample_rate (int, optional):
            target audio sampling rate. Defaults to None.

        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to 0.

        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.

        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to 80.

        use_mel (bool, optional):
            If True, project the magnitude spectrogram onto a mel filterbank; otherwise return the
            linear spectrogram. Defaults to False.

        do_amp_to_db (bool, optional):
            enable/disable amplitude-to-log compression (`amp_to_db`, natural log) of the output
            spectrogram, applied after the mel projection. Defaults to False.

        spec_gain (float, optional):
            gain applied when converting amplitude to DB. Defaults to 1.0.

        power (float, optional):
            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.

        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney. Defaults to False.

        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).

            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).

            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".

        normalized (bool, optional):
            Forwarded to `torch.stft` as its `normalized` argument. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        # Registered as a frozen parameter so it follows the module across devices.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: magnitude spectrogram frames. They are mel-projected when `use_mel` is set and
            log-compressed when `do_amp_to_db` is set.

        Shapes:
            x: [B x T] or [:math:`[B, 1, T]`]
            output: [B, n_fft // 2 + 1, frames], or [B, n_mels, frames] when `use_mel` is set.
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.view_as_real(
            torch.stft(
                x.squeeze(1),
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
                pad_mode="reflect",  # compatible with audio.py
                normalized=self.normalized,
                onesided=True,
                return_complex=True,
            )
        )
        real = o[:, :, :, 0]
        imag = o[:, :, :, 1]
        # Floor the squared magnitude so sqrt never sees exact zeros.
        S = torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            # The class has no `_amp_to_db` method (calling it raised AttributeError);
            # use the module-level helper shared with `spec_to_mel`.
            S = amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        """Build the (n_mels, n_fft // 2 + 1) librosa mel filterbank and store it as float32."""
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()
|