mirror of https://github.com/coqui-ai/TTS.git
refactor: move amp_to_db/db_to_amp into torch_transforms
This commit is contained in:
parent
33ac0d6ee1
commit
7cdfde226b
|
@ -9,7 +9,7 @@ import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from scipy.io.wavfile import read
|
from scipy.io.wavfile import read
|
||||||
|
|
||||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db
|
||||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel):
|
||||||
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
|
||||||
"""
|
|
||||||
PARAMS
|
|
||||||
------
|
|
||||||
C: compression factor
|
|
||||||
"""
|
|
||||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_decompression(x, C=1):
|
|
||||||
"""
|
|
||||||
PARAMS
|
|
||||||
------
|
|
||||||
C: compression factor used to compress
|
|
||||||
"""
|
|
||||||
return torch.exp(x) / C
|
|
||||||
|
|
||||||
|
|
||||||
def get_voices(extra_voice_dirs: List[str] = []):
|
def get_voices(extra_voice_dirs: List[str] = []):
|
||||||
dirs = extra_voice_dirs
|
dirs = extra_voice_dirs
|
||||||
voices: Dict[str, List[str]] = {}
|
voices: Dict[str, List[str]] = {}
|
||||||
|
@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
|
||||||
)
|
)
|
||||||
stft = stft.to(device)
|
stft = stft.to(device)
|
||||||
mel = stft(wav)
|
mel = stft(wav)
|
||||||
mel = dynamic_range_compression(mel)
|
mel = amp_to_db(mel)
|
||||||
if do_normalization:
|
if do_normalization:
|
||||||
mel = normalize_tacotron_mel(mel)
|
mel = normalize_tacotron_mel(mel)
|
||||||
return mel
|
return mel
|
||||||
|
|
|
@ -32,6 +32,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
|
||||||
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
||||||
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
||||||
from TTS.utils.audio.processor import AudioProcessor
|
from TTS.utils.audio.processor import AudioProcessor
|
||||||
|
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||||
from TTS.vocoder.utils.generic_utils import plot_results
|
from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
@ -136,24 +137,6 @@ def load_audio(file_path: str):
|
||||||
return x, sr
|
return x, sr
|
||||||
|
|
||||||
|
|
||||||
def _amp_to_db(x, C=1, clip_val=1e-5):
|
|
||||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
|
||||||
|
|
||||||
|
|
||||||
def _db_to_amp(x, C=1):
|
|
||||||
return torch.exp(x) / C
|
|
||||||
|
|
||||||
|
|
||||||
def amp_to_db(magnitudes):
|
|
||||||
output = _amp_to_db(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def db_to_amp(magnitudes):
|
|
||||||
output = _db_to_amp(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||||
y = y.squeeze(1)
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.tts.utils.visual import plot_alignment
|
from TTS.tts.utils.visual import plot_alignment
|
||||||
|
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||||
from TTS.utils.samplers import BucketBatchSampler
|
from TTS.utils.samplers import BucketBatchSampler
|
||||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||||
from TTS.vocoder.utils.generic_utils import plot_results
|
from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
@ -78,24 +79,6 @@ def load_audio(file_path):
|
||||||
return x, sr
|
return x, sr
|
||||||
|
|
||||||
|
|
||||||
def _amp_to_db(x, C=1, clip_val=1e-5):
|
|
||||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
|
||||||
|
|
||||||
|
|
||||||
def _db_to_amp(x, C=1):
|
|
||||||
return torch.exp(x) / C
|
|
||||||
|
|
||||||
|
|
||||||
def amp_to_db(magnitudes):
|
|
||||||
output = _amp_to_db(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def db_to_amp(magnitudes):
|
|
||||||
output = _db_to_amp(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||||
"""
|
"""
|
||||||
Args Shapes:
|
Args Shapes:
|
||||||
|
|
|
@ -59,7 +59,7 @@ def _exp(x, base):
|
||||||
return np.exp(x)
|
return np.exp(x)
|
||||||
|
|
||||||
|
|
||||||
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
|
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray:
|
||||||
"""Convert amplitude values to decibels.
|
"""Convert amplitude values to decibels.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -3,6 +3,16 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
|
||||||
|
"""Spectral normalization / dynamic range compression."""
|
||||||
|
return torch.log(torch.clamp(x, min=clip_val) * spec_gain)
|
||||||
|
|
||||||
|
|
||||||
|
def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
|
||||||
|
"""Spectral denormalization / dynamic range decompression."""
|
||||||
|
return torch.exp(x) / spec_gain
|
||||||
|
|
||||||
|
|
||||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||||
|
|
||||||
|
@ -157,11 +167,3 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
norm=self.mel_norm,
|
norm=self.mel_norm,
|
||||||
)
|
)
|
||||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _amp_to_db(x, spec_gain=1.0):
|
|
||||||
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _db_to_amp(x, spec_gain=1.0):
|
|
||||||
return torch.exp(x) / spec_gain
|
|
||||||
|
|
|
@ -4,39 +4,12 @@ import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
from librosa.filters import mel as librosa_mel_fn
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
|
||||||
|
from TTS.utils.audio.torch_transforms import amp_to_db
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_WAV_VALUE = 32768.0
|
MAX_WAV_VALUE = 32768.0
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
|
||||||
"""
|
|
||||||
PARAMS
|
|
||||||
------
|
|
||||||
C: compression factor
|
|
||||||
"""
|
|
||||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_decompression_torch(x, C=1):
|
|
||||||
"""
|
|
||||||
PARAMS
|
|
||||||
------
|
|
||||||
C: compression factor used to compress
|
|
||||||
"""
|
|
||||||
return torch.exp(x) / C
|
|
||||||
|
|
||||||
|
|
||||||
def spectral_normalize_torch(magnitudes):
|
|
||||||
output = dynamic_range_compression_torch(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def spectral_de_normalize_torch(magnitudes):
|
|
||||||
output = dynamic_range_decompression_torch(magnitudes)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
mel_basis = {}
|
mel_basis = {}
|
||||||
hann_window = {}
|
hann_window = {}
|
||||||
|
|
||||||
|
@ -85,7 +58,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||||
spec = spectral_normalize_torch(spec)
|
spec = amp_to_db(spec)
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,6 +101,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
|
||||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||||
spec = spectral_normalize_torch(spec)
|
spec = amp_to_db(spec)
|
||||||
|
|
||||||
return spec
|
return spec
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from TTS.utils.audio import numpy_transforms as np_transforms
|
||||||
|
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
|
||||||
|
|
||||||
|
|
||||||
|
def test_amplitude_db_conversion():
|
||||||
|
x = torch.rand(11)
|
||||||
|
o1 = amp_to_db(x=x, spec_gain=1.0)
|
||||||
|
o2 = db_to_amp(x=o1, spec_gain=1.0)
|
||||||
|
np_o1 = np_transforms.amp_to_db(x=x, base=np.e)
|
||||||
|
np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e)
|
||||||
|
assert torch.allclose(x, o2)
|
||||||
|
assert torch.allclose(o1, np_o1)
|
||||||
|
assert torch.allclose(o2, np_o2)
|
|
@ -13,14 +13,13 @@ from TTS.tts.models.vits import (
|
||||||
Vits,
|
Vits,
|
||||||
VitsArgs,
|
VitsArgs,
|
||||||
VitsAudioConfig,
|
VitsAudioConfig,
|
||||||
amp_to_db,
|
|
||||||
db_to_amp,
|
|
||||||
load_audio,
|
load_audio,
|
||||||
spec_to_mel,
|
spec_to_mel,
|
||||||
wav_to_mel,
|
wav_to_mel,
|
||||||
wav_to_spec,
|
wav_to_spec,
|
||||||
)
|
)
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
|
||||||
|
|
||||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||||
|
|
Loading…
Reference in New Issue