From 76df6421dead004a40b1ded1b12916282f013132 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 23 Nov 2024 01:16:50 +0100 Subject: [PATCH] refactor: move more audio processing into torch_transforms --- TTS/tts/models/delightful_tts.py | 139 +----------------------- TTS/tts/models/vits.py | 126 +-------------------- TTS/utils/audio/torch_transforms.py | 96 ++++++++++++++++ TTS/vc/modules/freevc/mel_processing.py | 48 -------- tests/tts_tests/test_vits.py | 5 +- 5 files changed, 100 insertions(+), 314 deletions(-) diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 88570047..e6db1160 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -9,7 +9,6 @@ import numpy as np import torch import torch.distributed as dist from coqpit import Coqpit -from librosa.filters import mel as librosa_mel_fn from torch import nn from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler @@ -38,7 +37,7 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor -from TTS.utils.audio.torch_transforms import amp_to_db +from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -50,145 +49,11 @@ hann_window = {} mel_basis = {} -def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global hann_window # pylint: disable=global-statement - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - return spec - - -def wav_to_spec(y, n_fft, hop_length, win_length, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - def wav_to_energy(y, n_fft, hop_length, win_length, center=False): - spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center) return torch.norm(spec, dim=1, keepdim=True) -def name_mel_basis(spec, n_fft, fmax): - n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" - return n_fft_len - - -def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): - """ - Args Shapes: - - spec : :math:`[B,C,T]` - - Return Shapes: - - mel : :math:`[B,C,T]` - """ - global mel_basis # pylint: disable=global-statement - mel_basis_key = name_mel_basis(spec, n_fft, fmax) - # pylint: disable=too-many-function-args - if mel_basis_key not in mel_basis: - # pylint: disable=missing-kwoa - mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) - mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - mel = torch.matmul(mel_basis[mel_basis_key], spec) - mel = amp_to_db(mel) - return mel - - -def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T_y]` - - Return Shapes: - - spec : :math:`[B,C,T_spec]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement - mel_basis_key = name_mel_basis(y, n_fft, fmax) - wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) - if mel_basis_key not in mel_basis: - # pylint: disable=missing-kwoa - mel = librosa_mel_fn( - sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax - ) # pylint: disable=too-many-function-args - mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.matmul(mel_basis[mel_basis_key], spec) - spec = amp_to_db(spec) - return spec - - ############################## # DATASET ############################## diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 30d9caff..7ec25192 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -10,7 +10,6 @@ import torch import torch.distributed as dist import torchaudio from coqpit import Coqpit -from librosa.filters import mel as librosa_mel_fn from monotonic_alignment_search import maximum_path from torch import nn from torch.nn import functional as F @@ -35,7 +34,7 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment -from TTS.utils.audio.torch_transforms import amp_to_db +from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -46,10 +45,6 @@ logger = logging.getLogger(__name__) # IO / Feature extraction ############################## -# pylint: disable=global-statement -hann_window = {} -mel_basis = {} - @torch.no_grad() def weights_reset(m: nn.Module): @@ -79,125 +74,6 @@ def load_audio(file_path): return x, sr -def wav_to_spec(y, n_fft, hop_length, win_length, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): - """ - Args Shapes: - - spec : :math:`[B,C,T]` - - Return Shapes: - - mel : :math:`[B,C,T]` - """ - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - mel = torch.matmul(mel_basis[fmax_dtype_device], spec) - mel = amp_to_db(mel) - return mel - - -def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = amp_to_db(spec) - return spec - - ############################# # CONFIGS ############################# diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index dda4c0a4..59bb23cc 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -1,7 +1,15 @@ +import logging + import librosa import torch from torch import nn +logger = logging.getLogger(__name__) + + +hann_window = {} +mel_basis = {} + def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor: """Spectral normalization / dynamic range compression.""" @@ -13,6 +21,94 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor: return torch.exp(x) / spec_gain +def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor: + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global hann_window + wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}" + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + return torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + +def spec_to_mel( + spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float +) -> torch.Tensor: + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + if fmax_dtype_device not in mel_basis: + # TODO: switch librosa to torchaudio + mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + return amp_to_db(mel) + + +def wav_to_mel( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sample_rate: int, + hop_length: int, + win_length: int, + fmin: float, + fmax: float, + *, + center: bool = False, +) -> torch.Tensor: + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center) + return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax) + + class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 4da5e27c..017d9002 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -14,54 +14,6 @@ mel_basis = {} hann_window = {} -def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.0: - logger.info("Min value is: %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("Max value is: %.3f", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = amp_to_db(spec) - return spec - - def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): if torch.min(y) < -1.0: logger.info("Min value is: %.3f", torch.min(y)) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index a27bdfe5..c8a52e1c 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -14,12 +14,9 @@ from TTS.tts.models.vits import ( VitsArgs, VitsAudioConfig, load_audio, - spec_to_mel, - wav_to_mel, - wav_to_spec, ) from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")