mirror of https://github.com/coqui-ai/TTS.git
refactor: move amp_to_db/db_to_amp into torch_transforms
This commit is contained in:
parent
33ac0d6ee1
commit
7cdfde226b
|
@ -9,7 +9,7 @@ import torch
|
|||
import torchaudio
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db
|
||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel):
|
|||
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def get_voices(extra_voice_dirs: List[str] = []):
|
||||
dirs = extra_voice_dirs
|
||||
voices: Dict[str, List[str]] = {}
|
||||
|
@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
|
|||
)
|
||||
stft = stft.to(device)
|
||||
mel = stft(wav)
|
||||
mel = dynamic_range_compression(mel)
|
||||
mel = amp_to_db(mel)
|
||||
if do_normalization:
|
||||
mel = normalize_tacotron_mel(mel)
|
||||
return mel
|
||||
|
|
|
@ -32,6 +32,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
|
|||
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
||||
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
||||
from TTS.utils.audio.processor import AudioProcessor
|
||||
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
@ -136,24 +137,6 @@ def load_audio(file_path: str):
|
|||
return x, sr
|
||||
|
||||
|
||||
def _amp_to_db(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def _db_to_amp(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def amp_to_db(magnitudes):
|
||||
output = _amp_to_db(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def db_to_amp(magnitudes):
|
||||
output = _db_to_amp(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||
y = y.squeeze(1)
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||
from TTS.utils.samplers import BucketBatchSampler
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
@ -78,24 +79,6 @@ def load_audio(file_path):
|
|||
return x, sr
|
||||
|
||||
|
||||
def _amp_to_db(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def _db_to_amp(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def amp_to_db(magnitudes):
|
||||
output = _amp_to_db(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def db_to_amp(magnitudes):
|
||||
output = _db_to_amp(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||
"""
|
||||
Args Shapes:
|
||||
|
|
|
@ -59,7 +59,7 @@ def _exp(x, base):
|
|||
return np.exp(x)
|
||||
|
||||
|
||||
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
|
||||
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray:
|
||||
"""Convert amplitude values to decibels.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -3,6 +3,16 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
|
||||
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
|
||||
"""Spectral normalization / dynamic range compression."""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * spec_gain)
|
||||
|
||||
|
||||
def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
|
||||
"""Spectral denormalization / dynamic range decompression."""
|
||||
return torch.exp(x) / spec_gain
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||
|
||||
|
@ -157,11 +167,3 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
norm=self.mel_norm,
|
||||
)
|
||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||
|
||||
@staticmethod
|
||||
def _amp_to_db(x, spec_gain=1.0):
|
||||
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
|
||||
|
||||
@staticmethod
|
||||
def _db_to_amp(x, spec_gain=1.0):
|
||||
return torch.exp(x) / spec_gain
|
||||
|
|
|
@ -4,39 +4,12 @@ import torch
|
|||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
@ -85,7 +58,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
|||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
spec = amp_to_db(spec)
|
||||
return spec
|
||||
|
||||
|
||||
|
@ -128,6 +101,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
|
|||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
spec = amp_to_db(spec)
|
||||
|
||||
return spec
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from TTS.utils.audio import numpy_transforms as np_transforms
|
||||
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
|
||||
|
||||
|
||||
def test_amplitude_db_conversion():
|
||||
x = torch.rand(11)
|
||||
o1 = amp_to_db(x=x, spec_gain=1.0)
|
||||
o2 = db_to_amp(x=o1, spec_gain=1.0)
|
||||
np_o1 = np_transforms.amp_to_db(x=x, base=np.e)
|
||||
np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e)
|
||||
assert torch.allclose(x, o2)
|
||||
assert torch.allclose(o1, np_o1)
|
||||
assert torch.allclose(o2, np_o2)
|
|
@ -13,14 +13,13 @@ from TTS.tts.models.vits import (
|
|||
Vits,
|
||||
VitsArgs,
|
||||
VitsAudioConfig,
|
||||
amp_to_db,
|
||||
db_to_amp,
|
||||
load_audio,
|
||||
spec_to_mel,
|
||||
wav_to_mel,
|
||||
wav_to_spec,
|
||||
)
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
|
||||
|
||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||
|
|
Loading…
Reference in New Issue