mirror of https://github.com/coqui-ai/TTS.git
pad_mode in torch_stft
This commit is contained in:
parent
599149a7e5
commit
4337e9ff87
|
@ -5,12 +5,14 @@ from torch.nn import functional as F
|
|||
|
||||
|
||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
|
||||
# TODO: move this to audio.py with a transparent torch API.
|
||||
def __init__(self, n_fft, hop_length, win_length, pad_mode='reflect', window='hann_window'):
|
||||
""" Torch based STFT operation """
|
||||
super(TorchSTFT, self).__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.pad_mode = pad_mode
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length),
|
||||
requires_grad=False)
|
||||
|
||||
|
@ -22,7 +24,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
self.win_length,
|
||||
self.window,
|
||||
center=True,
|
||||
pad_mode="reflect", # compatible with audio.py
|
||||
pad_mode=self.pad_mode, # needs to be compatible with audio.py
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False)
|
||||
|
|
Loading…
Reference in New Issue