refactor: move amp_to_db/db_to_amp into torch_transforms

Author: Enno Hermann
Date:   2024-11-22 21:30:21 +01:00
parent 33ac0d6ee1
commit 7cdfde226b
9 changed files with 36 additions and 98 deletions

View File

@@ -9,7 +9,7 @@ import torch
 import torchaudio
 from scipy.io.wavfile import read
-from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db
 from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 logger = logging.getLogger(__name__)
@@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel):
     return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    """
-    PARAMS
-    ------
-    C: compression factor
-    """
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-def dynamic_range_decompression(x, C=1):
-    """
-    PARAMS
-    ------
-    C: compression factor used to compress
-    """
-    return torch.exp(x) / C
 def get_voices(extra_voice_dirs: List[str] = []):
     dirs = extra_voice_dirs
     voices: Dict[str, List[str]] = {}
@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
)
stft = stft.to(device)
mel = stft(wav)
mel = dynamic_range_compression(mel)
mel = amp_to_db(mel)
if do_normalization:
mel = normalize_tacotron_mel(mel)
return mel
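Note: the call-site swap is behavior-preserving because the removed dynamic_range_compression and the shared torch_transforms.amp_to_db compute the same expression, log(clamp(x, min=clip_val) * gain), with matching defaults. A minimal sketch of that equivalence check (the removed helper is restated here only for the comparison; the input shape is illustrative):

    import torch

    from TTS.utils.audio.torch_transforms import amp_to_db

    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        # removed tortoise helper, reproduced for comparison only
        return torch.log(torch.clamp(x, min=clip_val) * C)

    x = torch.rand(80, 100)  # stand-in magnitude mel spectrogram
    assert torch.allclose(dynamic_range_compression(x), amp_to_db(x))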

View File

@@ -32,6 +32,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
 from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.utils.audio.processor import AudioProcessor
+from TTS.utils.audio.torch_transforms import amp_to_db
 from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -136,24 +137,6 @@ def load_audio(file_path: str):
     return x, sr
-def _amp_to_db(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-def _db_to_amp(x, C=1):
-    return torch.exp(x) / C
-def amp_to_db(magnitudes):
-    output = _amp_to_db(magnitudes)
-    return output
-def db_to_amp(magnitudes):
-    output = _db_to_amp(magnitudes)
-    return output
 def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
     y = y.squeeze(1)

View File

@@ -35,6 +35,7 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.audio.torch_transforms import amp_to_db
 from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -78,24 +79,6 @@ def load_audio(file_path):
     return x, sr
-def _amp_to_db(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-def _db_to_amp(x, C=1):
-    return torch.exp(x) / C
-def amp_to_db(magnitudes):
-    output = _amp_to_db(magnitudes)
-    return output
-def db_to_amp(magnitudes):
-    output = _db_to_amp(magnitudes)
-    return output
 def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
     """
     Args Shapes:

View File

@@ -59,7 +59,7 @@ def _exp(x, base):
     return np.exp(x)
-def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
+def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray:
     """Convert amplitude values to decibels.
     Args:
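Note: widening the annotation from base: int = 10 to base: float = 10 matters because callers may pass non-integer bases; the new torch/numpy parity test uses base=np.e so the numpy path follows the same natural-log convention as the torch amp_to_db. A small sketch, assuming db_to_amp mirrors the keyword-only signature shown above (as the new test also assumes):

    import numpy as np

    from TTS.utils.audio import numpy_transforms as np_transforms

    x = np.random.rand(11).astype(np.float32)
    db = np_transforms.amp_to_db(x=x, base=np.e)    # natural-log dB, matching the torch version
    amp = np_transforms.db_to_amp(x=db, base=np.e)  # inverse transform
    assert np.allclose(x, amp, atol=1e-5)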

View File

@@ -3,6 +3,16 @@ import torch
 from torch import nn
+def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
+    """Spectral normalization / dynamic range compression."""
+    return torch.log(torch.clamp(x, min=clip_val) * spec_gain)
+def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
+    """Spectral denormalization / dynamic range decompression."""
+    return torch.exp(x) / spec_gain
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     """Some of the audio processing funtions using Torch for faster batch processing.
@@ -157,11 +167,3 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
             norm=self.mel_norm,
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
-    @staticmethod
-    def _amp_to_db(x, spec_gain=1.0):
-        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
-    @staticmethod
-    def _db_to_amp(x, spec_gain=1.0):
-        return torch.exp(x) / spec_gain
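With the pair promoted to module-level functions, call sites that previously went through the removed TorchSTFT._amp_to_db/_db_to_amp static methods can import them directly. A minimal round-trip sketch (the tensor shape is illustrative only):

    import torch

    from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp

    spec = torch.rand(1, 513, 100)                # stand-in magnitude spectrogram batch
    spec_db = amp_to_db(spec, spec_gain=1.0)      # log-compress
    spec_amp = db_to_amp(spec_db, spec_gain=1.0)  # invert; exact for values above clip_val
    assert torch.allclose(spec_amp, torch.clamp(spec, min=1e-5))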

View File

@@ -4,39 +4,12 @@ import torch
 import torch.utils.data
 from librosa.filters import mel as librosa_mel_fn
+from TTS.utils.audio.torch_transforms import amp_to_db
 logger = logging.getLogger(__name__)
 MAX_WAV_VALUE = 32768.0
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    """
-    PARAMS
-    ------
-    C: compression factor
-    """
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-def dynamic_range_decompression_torch(x, C=1):
-    """
-    PARAMS
-    ------
-    C: compression factor used to compress
-    """
-    return torch.exp(x) / C
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
 mel_basis = {}
 hann_window = {}
@@ -85,7 +58,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
+    spec = amp_to_db(spec)
     return spec
@@ -128,6 +101,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
+    spec = amp_to_db(spec)
     return spec

View File

@@ -0,0 +1,16 @@
+import numpy as np
+import torch
+from TTS.utils.audio import numpy_transforms as np_transforms
+from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
+def test_amplitude_db_conversion():
+    x = torch.rand(11)
+    o1 = amp_to_db(x=x, spec_gain=1.0)
+    o2 = db_to_amp(x=o1, spec_gain=1.0)
+    np_o1 = np_transforms.amp_to_db(x=x, base=np.e)
+    np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e)
+    assert torch.allclose(x, o2)
+    assert torch.allclose(o1, np_o1)
+    assert torch.allclose(o2, np_o2)

View File

@@ -13,14 +13,13 @@ from TTS.tts.models.vits import (
     Vits,
     VitsArgs,
     VitsAudioConfig,
-    amp_to_db,
-    db_to_amp,
     load_audio,
     spec_to_mel,
     wav_to_mel,
     wav_to_spec,
 )
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")