refactor(audio.processor): remove duplicate quantization methods

This commit is contained in:
Enno Hermann 2023-11-15 16:19:56 +01:00
parent ddbaecdb5b
commit 8f1db7510a
6 changed files with 17 additions and 45 deletions

View File

@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters
use_cuda = torch.cuda.is_available()
@ -197,7 +198,7 @@ def extract_spectrograms(
# quantize and save wav
if quantize_bits > 0:
wavq = ap.quantize(wav, quantize_bits)
wavq = quantize(wav, quantize_bits)
np.save(wavq_path, wavq)
# save TTS mel

View File

@ -631,43 +631,3 @@ class AudioProcessor(object):
filename (str): Path to the wav file.
"""
return librosa.get_duration(filename=filename)
@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
mu = 2**qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)
@staticmethod
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2**qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
@staticmethod
def encode_16bits(x):
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**bits - 1) / 2
@staticmethod
def dequantize(x, bits):
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**bits - 1) - 1

View File

@ -7,6 +7,7 @@ from coqpit import Coqpit
from tqdm import tqdm
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
quant = (
mulaw_encode(wav=y, mulaw_qc=config.mode)
if config.model_args.mulaw
else quantize(x=y, quantize_bits=config.mode)
)
np.save(quant_path, quant)

View File

@ -2,6 +2,8 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
class WaveRNNDataset(Dataset):
"""
@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
x_input = audio
elif isinstance(self.mode, int):
x_input = (
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
mulaw_encode(wav=audio, mulaw_qc=self.mode)
if self.mulaw
else quantize(x=audio, quantize_bits=self.mode)
)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
output = output[0]
if self.args.mulaw and isinstance(self.args.mode, int):
output = AudioProcessor.mulaw_decode(output, self.args.mode)
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)

View File

@ -34,6 +34,7 @@
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.audio.numpy_transforms import quantize\n",
"\n",
"%matplotlib inline\n",
"\n",
@ -190,7 +191,7 @@
"\n",
" # quantize and save wav\n",
" if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
" wavq = quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" # save TTS mel\n",