Merge pull request #20 from eginhard/return-complex

fix: torch.stft will soon require return_complex=True
This commit is contained in:
Enno Hermann 2024-03-13 13:50:21 +01:00 committed by GitHub
commit e5c6da1c98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 77 deletions

View File

@ -179,17 +179,19 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_length, n_fft,
win_length=win_length, hop_length=hop_length,
window=hann_window[wnsize_dtype_device], win_length=win_length,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
return spec return spec
@ -274,17 +276,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_length, n_fft,
win_length=win_length, hop_length=hop_length,
window=hann_window[wnsize_dtype_device], win_length=win_length,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

View File

@ -121,17 +121,19 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_length, n_fft,
win_length=win_length, hop_length=hop_length,
window=hann_window[wnsize_dtype_device], win_length=win_length,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
@ -189,17 +191,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_length, n_fft,
win_length=win_length, hop_length=hop_length,
window=hann_window[wnsize_dtype_device], win_length=win_length,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

View File

@ -119,17 +119,19 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
padding = int((self.n_fft - self.hop_length) / 2) padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2 # B x D x T x 2
o = torch.stft( o = torch.view_as_real(
x.squeeze(1), torch.stft(
self.n_fft, x.squeeze(1),
self.hop_length, self.n_fft,
self.win_length, self.hop_length,
self.window, self.win_length,
center=True, self.window,
pad_mode="reflect", # compatible with audio.py center=True,
normalized=self.normalized, pad_mode="reflect", # compatible with audio.py
onesided=True, normalized=self.normalized,
return_complex=False, onesided=True,
return_complex=True,
)
) )
M = o[:, :, :, 0] M = o[:, :, :, 0]
P = o[:, :, :, 1] P = o[:, :, :, 1]

View File

@ -54,17 +54,19 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_size, n_fft,
win_length=win_size, hop_length=hop_size,
window=hann_window[wnsize_dtype_device], win_length=win_size,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
@ -104,17 +106,19 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
) )
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft( spec = torch.view_as_real(
y, torch.stft(
n_fft, y,
hop_length=hop_size, n_fft,
win_length=win_size, hop_length=hop_size,
window=hann_window[wnsize_dtype_device], win_length=win_size,
center=center, window=hann_window[wnsize_dtype_device],
pad_mode="reflect", center=center,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=False, onesided=True,
return_complex=True,
)
) )
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)