Move `TorchSTFT` to `utils.audio`

This commit is contained in:
Eren Gölge 2021-06-21 16:50:37 +02:00
parent 5b89cb4fec
commit d700845b10
2 changed files with 79 additions and 78 deletions

View File

@ -3,12 +3,89 @@ import numpy as np
import scipy.io.wavfile import scipy.io.wavfile
import scipy.signal import scipy.signal
import soundfile as sf import soundfile as sf
import torch
from torch import nn
from TTS.tts.utils.data import StandardScaler from TTS.tts.utils.data import StandardScaler
# import pyworld as pw # import pyworld as pw
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [B x 1 x T]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
)
self.mel_basis = torch.from_numpy(mel_basis).float()
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class AudioProcessor(object): class AudioProcessor(object):
"""Audio Processor for TTS used by all the data pipelines. """Audio Processor for TTS used by all the data pipelines.

View File

@ -1,88 +1,12 @@
from typing import Dict, Union from typing import Dict, Union
import librosa
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from TTS.utils.audio import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [B x 1 x T]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
)
self.mel_basis = torch.from_numpy(mel_basis).float()
################################# #################################
# GENERATOR LOSSES # GENERATOR LOSSES
################################# #################################
@ -275,7 +199,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
loss += total_loss loss += total_loss
real_loss += real_loss real_loss += real_loss
fake_loss += fake_loss fake_loss += fake_loss
# normalize loss values with number of scales # normalize loss values with number of scales (discriminators)
loss /= len(scores_fake) loss /= len(scores_fake)
real_loss /= len(scores_real) real_loss /= len(scores_real)
fake_loss /= len(scores_fake) fake_loss /= len(scores_fake)