pad_mode in torch_stft

This commit is contained in:
Eren Gölge 2021-03-10 14:41:00 +01:00
parent 599149a7e5
commit 4337e9ff87
1 changed files with 4 additions and 2 deletions

View File

@ -5,12 +5,14 @@ from torch.nn import functional as F
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
# TODO: move this to audio.py with a transparent torch API.
def __init__(self, n_fft, hop_length, win_length, pad_mode='reflect', window='hann_window'):
""" Torch based STFT operation """
super(TorchSTFT, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_mode = pad_mode
self.window = nn.Parameter(getattr(torch, window)(win_length),
requires_grad=False)
@ -22,7 +24,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
pad_mode=self.pad_mode, # needs to be compatible with audio.py
normalized=False,
onesided=True,
return_complex=False)