TorchSTFT to device fix

This commit is contained in:
gerazov 2021-01-16 12:21:16 +01:00
parent 7beaacc55b
commit c96f7a2614
1 changed files with 5 additions and 3 deletions

View File

@ -4,13 +4,14 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
class TorchSTFT(): class TorchSTFT(nn.Module):
def __init__(self, n_fft, hop_length, win_length, window='hann_window'): def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
""" Torch based STFT operation """ """ Torch based STFT operation """
super(TorchSTFT, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.win_length = win_length self.win_length = win_length
self.window = getattr(torch, window)(win_length) self.window = nn.Parameter(getattr(torch, window)(win_length))
def __call__(self, x): def __call__(self, x):
# B x D x T x 2 # B x D x T x 2
@ -22,7 +23,8 @@ class TorchSTFT():
center=True, center=True,
pad_mode="reflect", # compatible with audio.py pad_mode="reflect", # compatible with audio.py
normalized=False, normalized=False,
onesided=True) onesided=True,
return_complex=False)
M = o[:, :, :, 0] M = o[:, :, :, 0]
P = o[:, :, :, 1] P = o[:, :, :, 1]
return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))