From 7cdfde226bc03cc792424c4f3a93741150213cfc Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 21:30:21 +0100 Subject: [PATCH] refactor: move amp_to_db/db_to_amp into torch_transforms --- TTS/tts/layers/tortoise/audio_utils.py | 22 ++------------- TTS/tts/models/delightful_tts.py | 19 +------------ TTS/tts/models/vits.py | 19 +------------ TTS/utils/audio/numpy_transforms.py | 2 +- TTS/utils/audio/torch_transforms.py | 18 ++++++------ TTS/vc/modules/freevc/mel_processing.py | 35 +++--------------------- tests/aux_tests/test_stft_torch.py | 0 tests/aux_tests/test_torch_transforms.py | 16 +++++++++++ tests/tts_tests/test_vits.py | 3 +- 9 files changed, 36 insertions(+), 98 deletions(-) delete mode 100644 tests/aux_tests/test_stft_torch.py create mode 100644 tests/aux_tests/test_torch_transforms.py diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index 4f299a8f..c67ee6c4 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -9,7 +9,7 @@ import torch import torchaudio from scipy.io.wavfile import read -from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel): return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 -def dynamic_range_compression(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - def get_voices(extra_voice_dirs: List[str] = []): dirs = extra_voice_dirs voices: Dict[str, List[str]] = {} @@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"): ) stft = stft.to(device) mel = stft(wav) - mel = dynamic_range_compression(mel) + mel = amp_to_db(mel) if do_normalization: mel = normalize_tacotron_mel(mel) return mel diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index c6f15a79..880ea4ae 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -32,6 +32,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor +from TTS.utils.audio.torch_transforms import amp_to_db from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -136,24 +137,6 @@ def load_audio(file_path: str): return x, sr -def _amp_to_db(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def _db_to_amp(x, C=1): - return torch.exp(x) / C - - -def amp_to_db(magnitudes): - output = _amp_to_db(magnitudes) - return output - - -def db_to_amp(magnitudes): - output = _db_to_amp(magnitudes) - return output - - def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 432b29f5..aea0f4e4 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -35,6 +35,7 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment +from TTS.utils.audio.torch_transforms import amp_to_db from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -78,24 +79,6 @@ def load_audio(file_path): return x, sr -def _amp_to_db(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def _db_to_amp(x, C=1): - return torch.exp(x) / C - - -def amp_to_db(magnitudes): - output = _amp_to_db(magnitudes) - return output - - -def db_to_amp(magnitudes): - output = _db_to_amp(magnitudes) - return output - - def wav_to_spec(y, n_fft, hop_length, win_length, center=False): """ Args Shapes: diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index 203091ea..9c83009b 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -59,7 +59,7 @@ def _exp(x, base): return np.exp(x) -def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: +def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray: """Convert amplitude values to decibels. Args: diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index 632969c5..dda4c0a4 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -3,6 +3,16 @@ import torch from torch import nn +def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor: + """Spectral normalization / dynamic range compression.""" + return torch.log(torch.clamp(x, min=clip_val) * spec_gain) + + +def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor: + """Spectral denormalization / dynamic range decompression.""" + return torch.exp(x) / spec_gain + + class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. @@ -157,11 +167,3 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method norm=self.mel_norm, ) self.mel_basis = torch.from_numpy(mel_basis).float() - - @staticmethod - def _amp_to_db(x, spec_gain=1.0): - return torch.log(torch.clamp(x, min=1e-5) * spec_gain) - - @staticmethod - def _db_to_amp(x, spec_gain=1.0): - return torch.exp(x) / spec_gain diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index a3e25189..4da5e27c 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -4,39 +4,12 @@ import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn +from TTS.utils.audio.torch_transforms import amp_to_db + logger = logging.getLogger(__name__) MAX_WAV_VALUE = 32768.0 - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - mel_basis = {} hann_window = {} @@ -85,7 +58,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) + spec = amp_to_db(spec) return spec @@ -128,6 +101,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) + spec = amp_to_db(spec) return spec diff --git a/tests/aux_tests/test_stft_torch.py b/tests/aux_tests/test_stft_torch.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/aux_tests/test_torch_transforms.py b/tests/aux_tests/test_torch_transforms.py new file mode 100644 index 00000000..2da5a359 --- /dev/null +++ b/tests/aux_tests/test_torch_transforms.py @@ -0,0 +1,16 @@ +import numpy as np +import torch + +from TTS.utils.audio import numpy_transforms as np_transforms +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp + + +def test_amplitude_db_conversion(): + x = torch.rand(11) + o1 = amp_to_db(x=x, spec_gain=1.0) + o2 = db_to_amp(x=o1, spec_gain=1.0) + np_o1 = np_transforms.amp_to_db(x=x, base=np.e) + np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e) + assert torch.allclose(x, o2) + assert torch.allclose(o1, np_o1) + assert torch.allclose(o2, np_o2) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 17992773..a27bdfe5 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -13,14 +13,13 @@ from TTS.tts.models.vits import ( Vits, VitsArgs, VitsAudioConfig, - amp_to_db, - db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec, ) from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")