mirror of https://github.com/coqui-ai/TTS.git
refactor: move more audio processing into torch_transforms
This commit is contained in:
parent 2c82477a78
commit 76df6421de
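The commit consolidates the STFT/mel helpers that were duplicated across individual model files into TTS.utils.audio.torch_transforms, so callers import wav_to_spec, spec_to_mel, and wav_to_mel from one place instead of redefining them locally. A minimal usage sketch of the shared API as it looks after this change; the waveform shape and STFT settings below are illustrative assumptions, not values taken from the commit:

import torch

from TTS.utils.audio.torch_transforms import wav_to_mel

# Hypothetical batch of waveforms: [B, 1, T], values in [-1, 1] (illustrative).
wav = torch.rand(4, 1, 22050) * 2 - 1

# Log-scale mel spectrogram, [B, num_mels, T_frames]; all settings are illustrative.
mel = wav_to_mel(
    wav, n_fft=1024, num_mels=80, sample_rate=22050, hop_length=256, win_length=1024, fmin=0, fmax=8000, center=False
)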
@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.distributed as dist
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
@@ -38,7 +37,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.utils.audio.processor import AudioProcessor
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@@ -50,145 +49,11 @@ hann_window = {}
mel_basis = {}


def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        logger.info("min value is %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("max value is %.3f", torch.max(y))

    global hann_window # pylint: disable=global-statement
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_length) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    return spec


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
    spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
    return torch.norm(spec, dim=1, keepdim=True)


def name_mel_basis(spec, n_fft, fmax):
    n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
    return n_fft_len


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
    """
    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    global mel_basis # pylint: disable=global-statement
    mel_basis_key = name_mel_basis(spec, n_fft, fmax)
    # pylint: disable=too-many-function-args
    if mel_basis_key not in mel_basis:
        # pylint: disable=missing-kwoa
        mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[mel_basis_key], spec)
    mel = amp_to_db(mel)
    return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T_y]`

    Return Shapes:
        - spec : :math:`[B,C,T_spec]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        logger.info("min value is %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("max value is %.3f", torch.max(y))

    global mel_basis, hann_window # pylint: disable=global-statement
    mel_basis_key = name_mel_basis(y, n_fft, fmax)
    wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device)
    if mel_basis_key not in mel_basis:
        # pylint: disable=missing-kwoa
        mel = librosa_mel_fn(
            sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        ) # pylint: disable=too-many-function-args
        mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    spec = torch.matmul(mel_basis[mel_basis_key], spec)
    spec = amp_to_db(spec)
    return spec


##############################
# DATASET
##############################
@@ -10,7 +10,6 @@ import torch
import torch.distributed as dist
import torchaudio
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from monotonic_alignment_search import maximum_path
from torch import nn
from torch.nn import functional as F
@@ -35,7 +34,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@@ -46,10 +45,6 @@ logger = logging.getLogger(__name__)
# IO / Feature extraction
##############################

# pylint: disable=global-statement
hann_window = {}
mel_basis = {}


@torch.no_grad()
def weights_reset(m: nn.Module):
@@ -79,125 +74,6 @@ def load_audio(file_path):
    return x, sr


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        logger.info("min value is %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("max value is %.3f", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_length) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
    """
    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
    mel = amp_to_db(mel)
    return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        logger.info("min value is %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("max value is %.3f", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_length) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = amp_to_db(spec)
    return spec


#############################
# CONFIGS
#############################
@@ -1,7 +1,15 @@
import logging

import librosa
import torch
from torch import nn

logger = logging.getLogger(__name__)


hann_window = {}
mel_basis = {}


def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
    """Spectral normalization / dynamic range compression."""
@@ -13,6 +21,94 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
    return torch.exp(x) / spec_gain


def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        logger.info("min value is %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("max value is %.3f", torch.max(y))

    global hann_window
    wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}"
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)


def spec_to_mel(
    spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
    """
    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    global mel_basis
    fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
    if fmax_dtype_device not in mel_basis:
        # TODO: switch librosa to torchaudio
        mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
    return amp_to_db(mel)


def wav_to_mel(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sample_rate: int,
    hop_length: int,
    win_length: int,
    fmin: float,
    fmax: float,
    *,
    center: bool = False,
) -> torch.Tensor:
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
    return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)


class TorchSTFT(nn.Module): # pylint: disable=abstract-method
    """Some of the audio processing funtions using Torch for faster batch processing.
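As the hunk above shows, the new module-level wav_to_mel is simply spec_to_mel(wav_to_spec(...)), with the Hann window and mel filterbank cached in module-level dicts keyed by window size / n_fft, fmax, dtype, and device. A short sketch of that equivalence; the batch shape and STFT settings are illustrative assumptions, not values from the commit:

import torch

from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec

# Illustrative waveform batch, [B, 1, T], values in [-1, 1).
y = torch.rand(2, 1, 16000) * 2 - 1

spec = wav_to_spec(y, 1024, 256, 1024, center=False)   # [B, n_fft // 2 + 1, T_frames]
mel_a = spec_to_mel(spec, 1024, 80, 16000, 0, 8000)    # amp_to_db is applied inside spec_to_mel
mel_b = wav_to_mel(y, 1024, 80, 16000, 256, 1024, 0, 8000, center=False)
assert torch.allclose(mel_a, mel_b)  # one-call form matches the two-step composition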
@@ -14,54 +14,6 @@ mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        logger.info("Min value is: %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("Max value is: %.3f", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = amp_to_db(spec)
    return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        logger.info("Min value is: %.3f", torch.min(y))
@@ -14,12 +14,9 @@ from TTS.tts.models.vits import (
    VitsArgs,
    VitsAudioConfig,
    load_audio,
    spec_to_mel,
    wav_to_mel,
    wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")