refactor(audio): remove duplicate save_wav code

This commit is contained in:
Enno Hermann 2024-11-18 00:54:26 +01:00
parent 5784f6705a
commit 8ba3233ec6
2 changed files with 29 additions and 18 deletions

View File

@ -431,7 +431,16 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
return x return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **kwargs) -> None: def save_wav(
*,
wav: np.ndarray,
path: str,
sample_rate: int,
pipe_out=None,
do_rms_norm: bool = False,
db_level: float = -27.0,
**kwargs,
) -> None:
"""Save float waveform to a file using Scipy. """Save float waveform to a file using Scipy.
Args: Args:
@ -439,8 +448,16 @@ def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **k
path (str): Path to an output file. path (str): Path to an output file.
sample_rate (int): Sampling rate used for saving to the file. sample_rate (int): Sampling rate used for saving to the file.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
do_rms_norm (bool): Whether to apply RMS normalization instead of peak normalization.
db_level (float): Target dB level in RMS. Must not be None when `do_rms_norm` is True.
""" """
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) if do_rms_norm:
if db_level is None:
msg = "`db_level` cannot be None with `do_rms_norm=True`"
raise ValueError(msg)
wav_norm = rms_volume_norm(x=wav, db_level=db_level)
else:
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
wav_norm = wav_norm.astype(np.int16) wav_norm = wav_norm.astype(np.int16)
if pipe_out: if pipe_out:

View File

@ -1,11 +1,8 @@
import logging import logging
from io import BytesIO
from typing import Optional from typing import Optional
import librosa import librosa
import numpy as np import numpy as np
import scipy.io.wavfile
import scipy.signal
from TTS.tts.utils.helpers import StandardScaler from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import ( from TTS.utils.audio.numpy_transforms import (
@ -21,6 +18,7 @@ from TTS.utils.audio.numpy_transforms import (
millisec_to_length, millisec_to_length,
preemphasis, preemphasis,
rms_volume_norm, rms_volume_norm,
save_wav,
spec_to_mel, spec_to_mel,
stft, stft,
trim_silence, trim_silence,
@ -590,7 +588,7 @@ class AudioProcessor:
x = self.rms_volume_norm(x, self.db_level) x = self.rms_volume_norm(x, self.db_level)
return x return x
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None: def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
"""Save a waveform to a file using Scipy. """Save a waveform to a file using Scipy.
Args: Args:
@ -599,18 +597,14 @@ class AudioProcessor:
sr (int, optional): Sampling rate used for saving to the file. Defaults to None. sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
""" """
if self.do_rms_norm: save_wav(
wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767 wav=wav,
else: path=path,
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) sample_rate=sr if sr else self.sample_rate,
pipe_out=pipe_out,
wav_norm = wav_norm.astype(np.int16) do_rms_norm=self.do_rms_norm,
if pipe_out: db_level=self.db_level,
wav_buffer = BytesIO() )
scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
wav_buffer.seek(0)
pipe_out.buffer.write(wav_buffer.read())
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
def get_duration(self, filename: str) -> float: def get_duration(self, filename: str) -> float:
"""Get the duration of a wav file using Librosa. """Get the duration of a wav file using Librosa.