mirror of https://github.com/coqui-ai/TTS.git
refactor(audio.processor): remove duplicate quantization methods
This commit is contained in:
parent ddbaecdb5b
commit 8f1db7510a
|
@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
|
|||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import quantize
|
||||
from TTS.utils.generic_utils import count_parameters
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
@@ -197,7 +198,7 @@ def extract_spectrograms(
 
     # quantize and save wav
     if quantize_bits > 0:
-        wavq = ap.quantize(wav, quantize_bits)
+        wavq = quantize(wav, quantize_bits)
         np.save(wavq_path, wavq)
 
     # save TTS mel

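Note: `quantize` here is the free function from `TTS.utils.audio.numpy_transforms`, replacing the `AudioProcessor.quantize` static method removed in the next hunk. A minimal sketch of the round trip, copying the arithmetic of the removed methods (a matching `dequantize` counterpart is assumed to live in the same module):

```python
import numpy as np

def quantize(x: np.ndarray, quantize_bits: int) -> np.ndarray:
    # Map a [-1, 1] float waveform onto [0, 2**bits - 1].
    return (x + 1.0) * (2**quantize_bits - 1) / 2

def dequantize(x: np.ndarray, quantize_bits: int) -> np.ndarray:
    # Inverse mapping back to [-1, 1].
    return 2 * x / (2**quantize_bits - 1) - 1

wav = np.sin(np.linspace(0, 2 * np.pi, 16))
assert np.allclose(dequantize(quantize(wav, 10), 10), wav)
```
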
@@ -631,43 +631,3 @@ class AudioProcessor(object):
         filename (str): Path to the wav file.
         """
         return librosa.get_duration(filename=filename)
-
-    @staticmethod
-    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2**qc - 1
-        # wav_abs = np.minimum(np.abs(wav), 1.0)
-        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
-        # Quantize signal to the specified number of levels.
-        signal = (signal + 1) / 2 * mu + 0.5
-        return np.floor(
-            signal,
-        )
-
-    @staticmethod
-    def mulaw_decode(wav, qc):
-        """Recovers waveform from quantized values."""
-        mu = 2**qc - 1
-        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
-        return x
-
-    @staticmethod
-    def encode_16bits(x):
-        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
-
-    @staticmethod
-    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
-        """Quantize a waveform to a given number of bits.
-
-        Args:
-            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
-            bits (int): Number of quantization bits.
-
-        Returns:
-            np.ndarray: Quantized waveform.
-        """
-        return (x + 1.0) * (2**bits - 1) / 2
-
-    @staticmethod
-    def dequantize(x, bits):
-        """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2**bits - 1) - 1

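The static methods deleted above live on as free functions in `TTS.utils.audio.numpy_transforms`, as the call sites in the following hunks show. For reference, a self-contained sketch of the mu-law pair using the removed arithmetic (keyword names follow the new call sites; the rescale step in the round trip is an assumption about how callers bridge encode and decode):

```python
import numpy as np

def mulaw_encode(wav: np.ndarray, mulaw_qc: int) -> np.ndarray:
    # Compress with the mu-law curve, then map to integer levels [0, mu].
    mu = 2**mulaw_qc - 1
    signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
    return np.floor((signal + 1) / 2 * mu + 0.5)

def mulaw_decode(wav: np.ndarray, mulaw_qc: int) -> np.ndarray:
    # Expand a [-1, 1] companded signal back to the linear waveform.
    mu = 2**mulaw_qc - 1
    return np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)

wav = np.sin(np.linspace(0, 2 * np.pi, 64))
levels = mulaw_encode(wav, 10)
# Rescale integer levels to [-1, 1] before decoding.
roundtrip = mulaw_decode(2 * levels / (2**10 - 1) - 1, 10)
assert np.abs(roundtrip - wav).max() < 1e-2
```
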
@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from tqdm import tqdm
 
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
 
 
 def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):

@@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                mulaw_encode(wav=y, mulaw_qc=config.mode)
+                if config.model_args.mulaw
+                else quantize(x=y, quantize_bits=config.mode)
+            )
             np.save(quant_path, quant)
 
 
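The same branch, written as a standalone illustration with hypothetical stand-ins for `config.mode` and `config.model_args.mulaw` (when `mode` is an int it is the bit depth, and the mulaw flag picks companded versus linear quantization):

```python
import numpy as np
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize

y = np.random.uniform(-1, 1, 22050).astype(np.float32)  # stand-in waveform
mode, use_mulaw = 10, True  # hypothetical config values

quant = (
    mulaw_encode(wav=y, mulaw_qc=mode)      # 10-bit mu-law class indices
    if use_mulaw
    else quantize(x=y, quantize_bits=mode)  # 10-bit linear quantization
)
```
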
@@ -2,6 +2,8 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
+
 
 class WaveRNNDataset(Dataset):
     """

@@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
             x_input = audio
         elif isinstance(self.mode, int):
             x_input = (
-                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
+                mulaw_encode(wav=audio, mulaw_qc=self.mode)
+                if self.mulaw
+                else quantize(x=audio, quantize_bits=self.mode)
             )
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_decode
 from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss

@@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
             output = output[0]
 
         if self.args.mulaw and isinstance(self.args.mode, int):
-            output = AudioProcessor.mulaw_decode(output, self.args.mode)
+            output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
 
         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)

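At inference time, `mulaw_decode` expands the generated sample stream back to linear amplitude before the fade-out is applied. A sketch with placeholder values (`mode` and `hop_length` are hypothetical stand-ins for `self.args.mode` and `self.config.audio.hop_length`; applying `fade_out` by multiplying the tail is an assumption about the surrounding code):

```python
import numpy as np
from TTS.utils.audio.numpy_transforms import mulaw_decode

mode, hop_length = 10, 256                # hypothetical config values
output = np.random.uniform(-1, 1, 22050)  # stand-in for the generated waveform

output = mulaw_decode(wav=output, mulaw_qc=mode)

# Fade out the last 20 hops to avoid the signal cutting out suddenly.
fade_out = np.linspace(1, 0, 20 * hop_length)
output[-len(fade_out):] *= fade_out
```
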
@@ -34,6 +34,7 @@
     "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
     "from TTS.tts.utils.visual import plot_spectrogram\n",
     "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.audio.numpy_transforms import quantize\n",
     "\n",
     "%matplotlib inline\n",
     "\n",

@@ -190,7 +191,7 @@
     "\n",
     "    # quantize and save wav\n",
     "    if QUANTIZE_BITS > 0:\n",
-    "        wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
+    "        wavq = quantize(wav, QUANTIZE_BITS)\n",
     "        np.save(wavq_path, wavq)\n",
     "\n",
     "    # save TTS mel\n",