refactor(audio): remove duplicate save_wav code

This commit is contained in:
Enno Hermann 2024-11-18 00:54:26 +01:00
parent 5784f6705a
commit 8ba3233ec6
2 changed files with 29 additions and 18 deletions

View File

@ -431,7 +431,16 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **kwargs) -> None:
def save_wav(
*,
wav: np.ndarray,
path: str,
sample_rate: int,
pipe_out=None,
do_rms_norm: bool = False,
db_level: float = -27.0,
**kwargs,
) -> None:
"""Save float waveform to a file using Scipy.
Args:
@ -439,8 +448,16 @@ def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **k
path (str): Path to a output file.
sr (int): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
do_rms_norm (bool): Whether to apply RMS normalization
db_level (float): Target dB level in RMS.
"""
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
if do_rms_norm:
if db_level is None:
msg = "`db_level` cannot be None with `do_rms_norm=True`"
raise ValueError(msg)
wav_norm = rms_volume_norm(x=wav, db_level=db_level)
else:
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
wav_norm = wav_norm.astype(np.int16)
if pipe_out:

View File

@ -1,11 +1,8 @@
import logging
from io import BytesIO
from typing import Optional
import librosa
import numpy as np
import scipy.io.wavfile
import scipy.signal
from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import (
@ -21,6 +18,7 @@ from TTS.utils.audio.numpy_transforms import (
millisec_to_length,
preemphasis,
rms_volume_norm,
save_wav,
spec_to_mel,
stft,
trim_silence,
@ -590,7 +588,7 @@ class AudioProcessor:
x = self.rms_volume_norm(x, self.db_level)
return x
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
"""Save a waveform to a file using Scipy.
Args:
@ -599,18 +597,14 @@ class AudioProcessor:
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
"""
if self.do_rms_norm:
wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
else:
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
wav_norm = wav_norm.astype(np.int16)
if pipe_out:
wav_buffer = BytesIO()
scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
wav_buffer.seek(0)
pipe_out.buffer.write(wav_buffer.read())
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
save_wav(
wav=wav,
path=path,
sample_rate=sr if sr else self.sample_rate,
pipe_out=pipe_out,
do_rms_norm=self.do_rms_norm,
db_level=self.db_level,
)
def get_duration(self, filename: str) -> float:
"""Get the duration of a wav file using Librosa.