mirror of https://github.com/coqui-ai/TTS.git
refactor: move more audio processing into torch_transforms
This commit is contained in:
parent
2c82477a78
commit
76df6421de
|
@ -9,7 +9,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from librosa.filters import mel as librosa_mel_fn
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.sampler import WeightedRandomSampler
|
from torch.utils.data.sampler import WeightedRandomSampler
|
||||||
|
@ -38,7 +37,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
|
||||||
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
||||||
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
||||||
from TTS.utils.audio.processor import AudioProcessor
|
from TTS.utils.audio.processor import AudioProcessor
|
||||||
from TTS.utils.audio.torch_transforms import amp_to_db
|
from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec
|
||||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||||
from TTS.vocoder.utils.generic_utils import plot_results
|
from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
@ -50,145 +49,11 @@ hann_window = {}
|
||||||
mel_basis = {}
|
mel_basis = {}
|
||||||
|
|
||||||
|
|
||||||
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
if torch.min(y) < -1.0:
|
|
||||||
logger.info("min value is %.3f", torch.min(y))
|
|
||||||
if torch.max(y) > 1.0:
|
|
||||||
logger.info("max value is %.3f", torch.max(y))
|
|
||||||
|
|
||||||
global hann_window # pylint: disable=global-statement
|
|
||||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
|
||||||
wnsize_dtype_device = str(win_length) + "_" + dtype_device
|
|
||||||
if wnsize_dtype_device not in hann_window:
|
|
||||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
|
||||||
|
|
||||||
y = torch.nn.functional.pad(
|
|
||||||
y.unsqueeze(1),
|
|
||||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
|
||||||
mode="reflect",
|
|
||||||
)
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
spec = torch.view_as_real(
|
|
||||||
torch.stft(
|
|
||||||
y,
|
|
||||||
n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=hann_window[wnsize_dtype_device],
|
|
||||||
center=center,
|
|
||||||
pad_mode="reflect",
|
|
||||||
normalized=False,
|
|
||||||
onesided=True,
|
|
||||||
return_complex=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- y : :math:`[B, 1, T]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- spec : :math:`[B,C,T]`
|
|
||||||
"""
|
|
||||||
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
|
def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
|
||||||
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
||||||
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
return torch.norm(spec, dim=1, keepdim=True)
|
return torch.norm(spec, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
def name_mel_basis(spec, n_fft, fmax):
|
|
||||||
n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
|
|
||||||
return n_fft_len
|
|
||||||
|
|
||||||
|
|
||||||
def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- spec : :math:`[B,C,T]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- mel : :math:`[B,C,T]`
|
|
||||||
"""
|
|
||||||
global mel_basis # pylint: disable=global-statement
|
|
||||||
mel_basis_key = name_mel_basis(spec, n_fft, fmax)
|
|
||||||
# pylint: disable=too-many-function-args
|
|
||||||
if mel_basis_key not in mel_basis:
|
|
||||||
# pylint: disable=missing-kwoa
|
|
||||||
mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
|
|
||||||
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
|
||||||
mel = torch.matmul(mel_basis[mel_basis_key], spec)
|
|
||||||
mel = amp_to_db(mel)
|
|
||||||
return mel
|
|
||||||
|
|
||||||
|
|
||||||
def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- y : :math:`[B, 1, T_y]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- spec : :math:`[B,C,T_spec]`
|
|
||||||
"""
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
if torch.min(y) < -1.0:
|
|
||||||
logger.info("min value is %.3f", torch.min(y))
|
|
||||||
if torch.max(y) > 1.0:
|
|
||||||
logger.info("max value is %.3f", torch.max(y))
|
|
||||||
|
|
||||||
global mel_basis, hann_window # pylint: disable=global-statement
|
|
||||||
mel_basis_key = name_mel_basis(y, n_fft, fmax)
|
|
||||||
wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device)
|
|
||||||
if mel_basis_key not in mel_basis:
|
|
||||||
# pylint: disable=missing-kwoa
|
|
||||||
mel = librosa_mel_fn(
|
|
||||||
sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
|
||||||
) # pylint: disable=too-many-function-args
|
|
||||||
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
|
||||||
if wnsize_dtype_device not in hann_window:
|
|
||||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
|
||||||
|
|
||||||
y = torch.nn.functional.pad(
|
|
||||||
y.unsqueeze(1),
|
|
||||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
|
||||||
mode="reflect",
|
|
||||||
)
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
spec = torch.view_as_real(
|
|
||||||
torch.stft(
|
|
||||||
y,
|
|
||||||
n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=hann_window[wnsize_dtype_device],
|
|
||||||
center=center,
|
|
||||||
pad_mode="reflect",
|
|
||||||
normalized=False,
|
|
||||||
onesided=True,
|
|
||||||
return_complex=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
spec = torch.matmul(mel_basis[mel_basis_key], spec)
|
|
||||||
spec = amp_to_db(spec)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# DATASET
|
# DATASET
|
||||||
##############################
|
##############################
|
||||||
|
|
|
@ -10,7 +10,6 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from librosa.filters import mel as librosa_mel_fn
|
|
||||||
from monotonic_alignment_search import maximum_path
|
from monotonic_alignment_search import maximum_path
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
@ -35,7 +34,7 @@ from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.tts.utils.visual import plot_alignment
|
from TTS.tts.utils.visual import plot_alignment
|
||||||
from TTS.utils.audio.torch_transforms import amp_to_db
|
from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec
|
||||||
from TTS.utils.samplers import BucketBatchSampler
|
from TTS.utils.samplers import BucketBatchSampler
|
||||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||||
from TTS.vocoder.utils.generic_utils import plot_results
|
from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
@ -46,10 +45,6 @@ logger = logging.getLogger(__name__)
|
||||||
# IO / Feature extraction
|
# IO / Feature extraction
|
||||||
##############################
|
##############################
|
||||||
|
|
||||||
# pylint: disable=global-statement
|
|
||||||
hann_window = {}
|
|
||||||
mel_basis = {}
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def weights_reset(m: nn.Module):
|
def weights_reset(m: nn.Module):
|
||||||
|
@ -79,125 +74,6 @@ def load_audio(file_path):
|
||||||
return x, sr
|
return x, sr
|
||||||
|
|
||||||
|
|
||||||
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- y : :math:`[B, 1, T]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- spec : :math:`[B,C,T]`
|
|
||||||
"""
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
if torch.min(y) < -1.0:
|
|
||||||
logger.info("min value is %.3f", torch.min(y))
|
|
||||||
if torch.max(y) > 1.0:
|
|
||||||
logger.info("max value is %.3f", torch.max(y))
|
|
||||||
|
|
||||||
global hann_window
|
|
||||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
|
||||||
wnsize_dtype_device = str(win_length) + "_" + dtype_device
|
|
||||||
if wnsize_dtype_device not in hann_window:
|
|
||||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
|
||||||
|
|
||||||
y = torch.nn.functional.pad(
|
|
||||||
y.unsqueeze(1),
|
|
||||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
|
||||||
mode="reflect",
|
|
||||||
)
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
spec = torch.view_as_real(
|
|
||||||
torch.stft(
|
|
||||||
y,
|
|
||||||
n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=hann_window[wnsize_dtype_device],
|
|
||||||
center=center,
|
|
||||||
pad_mode="reflect",
|
|
||||||
normalized=False,
|
|
||||||
onesided=True,
|
|
||||||
return_complex=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- spec : :math:`[B,C,T]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- mel : :math:`[B,C,T]`
|
|
||||||
"""
|
|
||||||
global mel_basis
|
|
||||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
|
||||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
|
||||||
if fmax_dtype_device not in mel_basis:
|
|
||||||
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
|
||||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
|
||||||
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
|
||||||
mel = amp_to_db(mel)
|
|
||||||
return mel
|
|
||||||
|
|
||||||
|
|
||||||
def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
|
|
||||||
"""
|
|
||||||
Args Shapes:
|
|
||||||
- y : :math:`[B, 1, T]`
|
|
||||||
|
|
||||||
Return Shapes:
|
|
||||||
- spec : :math:`[B,C,T]`
|
|
||||||
"""
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
if torch.min(y) < -1.0:
|
|
||||||
logger.info("min value is %.3f", torch.min(y))
|
|
||||||
if torch.max(y) > 1.0:
|
|
||||||
logger.info("max value is %.3f", torch.max(y))
|
|
||||||
|
|
||||||
global mel_basis, hann_window
|
|
||||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
|
||||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
|
||||||
wnsize_dtype_device = str(win_length) + "_" + dtype_device
|
|
||||||
if fmax_dtype_device not in mel_basis:
|
|
||||||
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
|
||||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
|
||||||
if wnsize_dtype_device not in hann_window:
|
|
||||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
|
||||||
|
|
||||||
y = torch.nn.functional.pad(
|
|
||||||
y.unsqueeze(1),
|
|
||||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
|
||||||
mode="reflect",
|
|
||||||
)
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
spec = torch.view_as_real(
|
|
||||||
torch.stft(
|
|
||||||
y,
|
|
||||||
n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=hann_window[wnsize_dtype_device],
|
|
||||||
center=center,
|
|
||||||
pad_mode="reflect",
|
|
||||||
normalized=False,
|
|
||||||
onesided=True,
|
|
||||||
return_complex=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
|
||||||
spec = amp_to_db(spec)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
# CONFIGS
|
# CONFIGS
|
||||||
#############################
|
#############################
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
hann_window = {}
|
||||||
|
mel_basis = {}
|
||||||
|
|
||||||
|
|
||||||
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
|
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
|
||||||
"""Spectral normalization / dynamic range compression."""
|
"""Spectral normalization / dynamic range compression."""
|
||||||
|
@ -13,6 +21,94 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
|
||||||
return torch.exp(x) / spec_gain
|
return torch.exp(x) / spec_gain
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args Shapes:
|
||||||
|
- y : :math:`[B, 1, T]`
|
||||||
|
|
||||||
|
Return Shapes:
|
||||||
|
- spec : :math:`[B,C,T]`
|
||||||
|
"""
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
if torch.min(y) < -1.0:
|
||||||
|
logger.info("min value is %.3f", torch.min(y))
|
||||||
|
if torch.max(y) > 1.0:
|
||||||
|
logger.info("max value is %.3f", torch.max(y))
|
||||||
|
|
||||||
|
global hann_window
|
||||||
|
wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}"
|
||||||
|
if wnsize_dtype_device not in hann_window:
|
||||||
|
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
||||||
|
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1),
|
||||||
|
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||||
|
mode="reflect",
|
||||||
|
)
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.view_as_real(
|
||||||
|
torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=hann_window[wnsize_dtype_device],
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def spec_to_mel(
|
||||||
|
spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args Shapes:
|
||||||
|
- spec : :math:`[B,C,T]`
|
||||||
|
|
||||||
|
Return Shapes:
|
||||||
|
- mel : :math:`[B,C,T]`
|
||||||
|
"""
|
||||||
|
global mel_basis
|
||||||
|
fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
|
||||||
|
if fmax_dtype_device not in mel_basis:
|
||||||
|
# TODO: switch librosa to torchaudio
|
||||||
|
mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
|
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||||
|
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||||
|
return amp_to_db(mel)
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_mel(
|
||||||
|
y: torch.Tensor,
|
||||||
|
n_fft: int,
|
||||||
|
num_mels: int,
|
||||||
|
sample_rate: int,
|
||||||
|
hop_length: int,
|
||||||
|
win_length: int,
|
||||||
|
fmin: float,
|
||||||
|
fmax: float,
|
||||||
|
*,
|
||||||
|
center: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args Shapes:
|
||||||
|
- y : :math:`[B, 1, T]`
|
||||||
|
|
||||||
|
Return Shapes:
|
||||||
|
- spec : :math:`[B,C,T]`
|
||||||
|
"""
|
||||||
|
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
||||||
|
return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)
|
||||||
|
|
||||||
|
|
||||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||||
|
|
||||||
|
|
|
@ -14,54 +14,6 @@ mel_basis = {}
|
||||||
hann_window = {}
|
hann_window = {}
|
||||||
|
|
||||||
|
|
||||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
|
||||||
if torch.min(y) < -1.0:
|
|
||||||
logger.info("Min value is: %.3f", torch.min(y))
|
|
||||||
if torch.max(y) > 1.0:
|
|
||||||
logger.info("Max value is: %.3f", torch.max(y))
|
|
||||||
|
|
||||||
global hann_window
|
|
||||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
|
||||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
|
||||||
if wnsize_dtype_device not in hann_window:
|
|
||||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
|
||||||
|
|
||||||
y = torch.nn.functional.pad(
|
|
||||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
|
||||||
)
|
|
||||||
y = y.squeeze(1)
|
|
||||||
|
|
||||||
spec = torch.view_as_real(
|
|
||||||
torch.stft(
|
|
||||||
y,
|
|
||||||
n_fft,
|
|
||||||
hop_length=hop_size,
|
|
||||||
win_length=win_size,
|
|
||||||
window=hann_window[wnsize_dtype_device],
|
|
||||||
center=center,
|
|
||||||
pad_mode="reflect",
|
|
||||||
normalized=False,
|
|
||||||
onesided=True,
|
|
||||||
return_complex=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
|
||||||
global mel_basis
|
|
||||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
|
||||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
|
||||||
if fmax_dtype_device not in mel_basis:
|
|
||||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
|
||||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
|
||||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
|
||||||
spec = amp_to_db(spec)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||||
if torch.min(y) < -1.0:
|
if torch.min(y) < -1.0:
|
||||||
logger.info("Min value is: %.3f", torch.min(y))
|
logger.info("Min value is: %.3f", torch.min(y))
|
||||||
|
|
|
@ -14,12 +14,9 @@ from TTS.tts.models.vits import (
|
||||||
VitsArgs,
|
VitsArgs,
|
||||||
VitsAudioConfig,
|
VitsAudioConfig,
|
||||||
load_audio,
|
load_audio,
|
||||||
spec_to_mel,
|
|
||||||
wav_to_mel,
|
|
||||||
wav_to_spec,
|
|
||||||
)
|
)
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
|
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec
|
||||||
|
|
||||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||||
|
|
Loading…
Reference in New Issue