mirror of https://github.com/coqui-ai/TTS.git
docstring and optional padding in TorchSTFT
parent f890454de3
commit c86c559349
@@ -4,23 +4,24 @@ from torch import nn
 from torch.nn import functional as F


-class TorchSTFT(nn.Module):
+class TorchSTFT(nn.Module): # pylint: disable=abstract-method
     """TODO: Merge this with audio.py"""
     def __init__(self,
                  n_fft,
                  hop_length,
                  win_length,
+                 pad_wav=False,
                  window='hann_window',
                  sample_rate=None,
                  mel_fmin=0,
                  mel_fmax=None,
                  n_mels=80,
                  use_mel=False):
-        """ Torch based STFT operation """
         super(TorchSTFT, self).__init__()
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.win_length = win_length
+        self.pad_wav = pad_wav
         self.sample_rate = sample_rate
         self.mel_fmin = mel_fmin
         self.mel_fmax = mel_fmax
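
The new pad_wav flag reflect-pads the waveform by (n_fft - hop_length) / 2 samples on each side, so an STFT with center=False emits exactly T // hop_length frames and the spectrogram stays aligned with hop-sized chunks of the waveform. A minimal sketch of that frame-count arithmetic, with made-up values and the current torch.stft signature (return_complex is a newer API than this commit):

import torch

n_fft, hop_length, T = 1024, 256, 8192    # example values, not from the commit
x = torch.randn(1, 1, T)                  # dummy [B x 1 x T] waveform
padding = int((n_fft - hop_length) / 2)   # 384 samples per side
x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
# with center=False: n_frames = (T + 2*padding - n_fft) // hop_length + 1
frames = torch.stft(x.squeeze(1), n_fft, hop_length, win_length=n_fft,
                    window=torch.hann_window(n_fft), center=False,
                    return_complex=True)
print(frames.shape[-1], T // hop_length)  # both 32
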
@@ -33,6 +34,20 @@ class TorchSTFT(nn.Module):
         self._build_mel_basis()

     def __call__(self, x):
-        padding = int((self.n_fft - self.hop_length) / 2)
-        x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
+        """Compute spectrogram frames by torch based stft.
+
+        Args:
+            x (Tensor): input waveform
+
+        Returns:
+            Tensor: spectrogram frames.
+
+        Shapes:
+            x: [B x T] or [B x 1 x T]
+        """
+        if x.ndim == 2:
+            x = x.unsqueeze(1)
+        if self.pad_wav:
+            padding = int((self.n_fft - self.hop_length) / 2)
+            x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
         # B x D x T x 2
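
A hypothetical usage of the patched __call__, assuming the TorchSTFT class from this diff is in scope (the rest of the module is not shown in these hunks); per the new docstring, both accepted input shapes yield the same frames:

import torch

# TorchSTFT is the class patched above; constructor values are illustrative.
stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024, pad_wav=True)
wav = torch.randn(4, 22050)        # [B x T]
spec_a = stft(wav)                 # 2D input is unsqueezed internally
spec_b = stft(wav.unsqueeze(1))    # [B x 1 x T] passed explicitly
assert spec_a.shape == spec_b.shape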