refactor: move more audio processing into torch_transforms

This commit is contained in:
Enno Hermann 2024-11-23 01:16:50 +01:00
parent 2c82477a78
commit 76df6421de
5 changed files with 100 additions and 314 deletions

View File

@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.distributed as dist
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
@ -38,7 +37,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.utils.audio.processor import AudioProcessor
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@ -50,145 +49,11 @@ hann_window = {}
mel_basis = {}
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)
if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))
global hann_window # pylint: disable=global-statement
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
return spec
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
return torch.norm(spec, dim=1, keepdim=True)
def name_mel_basis(spec, n_fft, fmax):
n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
return n_fft_len
def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis # pylint: disable=global-statement
mel_basis_key = name_mel_basis(spec, n_fft, fmax)
# pylint: disable=too-many-function-args
if mel_basis_key not in mel_basis:
# pylint: disable=missing-kwoa
mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[mel_basis_key], spec)
mel = amp_to_db(mel)
return mel
def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T_y]`
Return Shapes:
- spec : :math:`[B,C,T_spec]`
"""
y = y.squeeze(1)
if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
mel_basis_key = name_mel_basis(y, n_fft, fmax)
wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device)
if mel_basis_key not in mel_basis:
# pylint: disable=missing-kwoa
mel = librosa_mel_fn(
sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
) # pylint: disable=too-many-function-args
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.matmul(mel_basis[mel_basis_key], spec)
spec = amp_to_db(spec)
return spec
##############################
# DATASET
##############################

View File

@ -10,7 +10,6 @@ import torch
import torch.distributed as dist
import torchaudio
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from monotonic_alignment_search import maximum_path
from torch import nn
from torch.nn import functional as F
@ -35,7 +34,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@ -46,10 +45,6 @@ logger = logging.getLogger(__name__)
# IO / Feature extraction
##############################
# pylint: disable=global-statement
hann_window = {}
mel_basis = {}
@torch.no_grad()
def weights_reset(m: nn.Module):
@ -79,125 +74,6 @@ def load_audio(file_path):
return x, sr
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)
if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
mel = amp_to_db(mel)
return mel
def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)
if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = amp_to_db(spec)
return spec
#############################
# CONFIGS
#############################

View File

@ -1,7 +1,15 @@
import logging
import librosa
import torch
from torch import nn
logger = logging.getLogger(__name__)
hann_window = {}
mel_basis = {}
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
"""Spectral normalization / dynamic range compression."""
@ -13,6 +21,94 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
return torch.exp(x) / spec_gain
def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)
if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))
global hann_window
wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}"
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
def spec_to_mel(
spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
if fmax_dtype_device not in mel_basis:
# TODO: switch librosa to torchaudio
mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
return amp_to_db(mel)
def wav_to_mel(
y: torch.Tensor,
n_fft: int,
num_mels: int,
sample_rate: int,
hop_length: int,
win_length: int,
fmin: float,
fmax: float,
*,
center: bool = False,
) -> torch.Tensor:
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.

View File

@ -14,54 +14,6 @@ mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("Max value is: %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = amp_to_db(spec)
return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))

View File

@ -14,12 +14,9 @@ from TTS.tts.models.vits import (
VitsArgs,
VitsAudioConfig,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")