From 8ba3233ec607c648e3720086ff75994c1b39677f Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Mon, 18 Nov 2024 00:54:26 +0100
Subject: [PATCH] refactor(audio): remove duplicate save_wav code

---
Review note: the RMS-normalised branch of numpy_transforms.save_wav() now
multiplies by 32767 before the int16 cast, exactly as the removed
AudioProcessor.save_wav() body did. Without that factor the float-scale
output of rms_volume_norm() is truncated by astype(np.int16). Only one added
line changed, so hunk sizes and the diffstat are unaffected; the post-image
blob id on the first `index` line was not recomputed.

 TTS/utils/audio/numpy_transforms.py | 21 +++++++++++++++++++--
 TTS/utils/audio/processor.py        | 26 ++++++++++----------------
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py
index cf717c7a..203091ea 100644
--- a/TTS/utils/audio/numpy_transforms.py
+++ b/TTS/utils/audio/numpy_transforms.py
@@ -431,7 +431,16 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
     return x
 
 
-def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **kwargs) -> None:
+def save_wav(
+    *,
+    wav: np.ndarray,
+    path: str,
+    sample_rate: int,
+    pipe_out=None,
+    do_rms_norm: bool = False,
+    db_level: float = -27.0,
+    **kwargs,
+) -> None:
     """Save float waveform to a file using Scipy.
 
     Args:
@@ -439,8 +448,16 @@ def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **k
         path (str): Path to a output file.
         sr (int): Sampling rate used for saving to the file. Defaults to None.
         pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
+        do_rms_norm (bool): Whether to apply RMS normalization
+        db_level (float): Target dB level in RMS.
     """
-    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+    if do_rms_norm:
+        if db_level is None:
+            msg = "`db_level` cannot be None with `do_rms_norm=True`"
+            raise ValueError(msg)
+        wav_norm = rms_volume_norm(x=wav, db_level=db_level) * 32767  # scale to int16 range
+    else:
+        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
 
     wav_norm = wav_norm.astype(np.int16)
     if pipe_out:
diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py
index 90d0d755..fe125ace 100644
--- a/TTS/utils/audio/processor.py
+++ b/TTS/utils/audio/processor.py
@@ -1,11 +1,8 @@
 import logging
-from io import BytesIO
 from typing import Optional
 
 import librosa
 import numpy as np
-import scipy.io.wavfile
-import scipy.signal
 
 from TTS.tts.utils.helpers import StandardScaler
 from TTS.utils.audio.numpy_transforms import (
@@ -21,6 +18,7 @@ from TTS.utils.audio.numpy_transforms import (
     millisec_to_length,
     preemphasis,
     rms_volume_norm,
+    save_wav,
     spec_to_mel,
     stft,
     trim_silence,
@@ -590,7 +588,7 @@ class AudioProcessor:
             x = self.rms_volume_norm(x, self.db_level)
         return x
 
-    def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
+    def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
         """Save a waveform to a file using Scipy.
 
         Args:
@@ -599,18 +597,14 @@ class AudioProcessor:
             sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
             pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
         """
-        if self.do_rms_norm:
-            wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
-        else:
-            wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
-
-        wav_norm = wav_norm.astype(np.int16)
-        if pipe_out:
-            wav_buffer = BytesIO()
-            scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
-            wav_buffer.seek(0)
-            pipe_out.buffer.write(wav_buffer.read())
-        scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
+        save_wav(
+            wav=wav,
+            path=path,
+            sample_rate=sr if sr else self.sample_rate,
+            pipe_out=pipe_out,
+            do_rms_norm=self.do_rms_norm,
+            db_level=self.db_level,
+        )
 
     def get_duration(self, filename: str) -> float:
         """Get the duration of a wav file using Librosa.