docstring and optional padding in TorchSTFT

This commit is contained in:
Eren Gölge 2021-04-07 12:36:15 +02:00
parent f890454de3
commit c86c559349
1 changed files with 19 additions and 4 deletions

View File

@ -4,23 +4,24 @@ from torch import nn
from torch.nn import functional as F
class TorchSTFT(nn.Module):
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
def __init__(self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window='hann_window',
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False):
""" Torch based STFT operation """
super(TorchSTFT, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
@ -33,8 +34,22 @@ class TorchSTFT(nn.Module):
self._build_mel_basis()
def __call__(self, x):
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [B x 1 x T]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
# B x D x T x 2
o = torch.stft(
x.squeeze(1),