mirror of https://github.com/coqui-ai/TTS.git
refactor(audio): remove duplicate save_wav code
This commit is contained in:
parent
5784f6705a
commit
8ba3233ec6
|
@ -431,7 +431,16 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
|
|||
return x
|
||||
|
||||
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **kwargs) -> None:
|
||||
def save_wav(
|
||||
*,
|
||||
wav: np.ndarray,
|
||||
path: str,
|
||||
sample_rate: int,
|
||||
pipe_out=None,
|
||||
do_rms_norm: bool = False,
|
||||
db_level: float = -27.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Save float waveform to a file using Scipy.
|
||||
|
||||
Args:
|
||||
|
@ -439,8 +448,16 @@ def save_wav(*, wav: np.ndarray, path: str, sample_rate: int, pipe_out=None, **k
|
|||
path (str): Path to a output file.
|
||||
sr (int): Sampling rate used for saving to the file. Defaults to None.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
do_rms_norm (bool): Whether to apply RMS normalization
|
||||
db_level (float): Target dB level in RMS.
|
||||
"""
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
if do_rms_norm:
|
||||
if db_level is None:
|
||||
msg = "`db_level` cannot be None with `do_rms_norm=True`"
|
||||
raise ValueError(msg)
|
||||
wav_norm = rms_volume_norm(x=wav, db_level=db_level)
|
||||
else:
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
|
||||
wav_norm = wav_norm.astype(np.int16)
|
||||
if pipe_out:
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
|
||||
from TTS.tts.utils.helpers import StandardScaler
|
||||
from TTS.utils.audio.numpy_transforms import (
|
||||
|
@ -21,6 +18,7 @@ from TTS.utils.audio.numpy_transforms import (
|
|||
millisec_to_length,
|
||||
preemphasis,
|
||||
rms_volume_norm,
|
||||
save_wav,
|
||||
spec_to_mel,
|
||||
stft,
|
||||
trim_silence,
|
||||
|
@ -590,7 +588,7 @@ class AudioProcessor:
|
|||
x = self.rms_volume_norm(x, self.db_level)
|
||||
return x
|
||||
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
|
||||
"""Save a waveform to a file using Scipy.
|
||||
|
||||
Args:
|
||||
|
@ -599,18 +597,14 @@ class AudioProcessor:
|
|||
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
|
||||
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
|
||||
"""
|
||||
if self.do_rms_norm:
|
||||
wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
|
||||
else:
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
|
||||
wav_norm = wav_norm.astype(np.int16)
|
||||
if pipe_out:
|
||||
wav_buffer = BytesIO()
|
||||
scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
|
||||
wav_buffer.seek(0)
|
||||
pipe_out.buffer.write(wav_buffer.read())
|
||||
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
|
||||
save_wav(
|
||||
wav=wav,
|
||||
path=path,
|
||||
sample_rate=sr if sr else self.sample_rate,
|
||||
pipe_out=pipe_out,
|
||||
do_rms_norm=self.do_rms_norm,
|
||||
db_level=self.db_level,
|
||||
)
|
||||
|
||||
def get_duration(self, filename: str) -> float:
|
||||
"""Get the duration of a wav file using Librosa.
|
||||
|
|
Loading…
Reference in New Issue