From c96f7a2614ae336ac8b4c1657af444846719cd83 Mon Sep 17 00:00:00 2001 From: gerazov Date: Sat, 16 Jan 2021 12:21:16 +0100 Subject: [PATCH] TorchSTFT to device fix --- TTS/vocoder/layers/losses.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index e705b1e0..d49f2725 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,13 +4,14 @@ from torch import nn from torch.nn import functional as F -class TorchSTFT(): +class TorchSTFT(nn.Module): def __init__(self, n_fft, hop_length, win_length, window='hann_window'): """ Torch based STFT operation """ + super(TorchSTFT, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length - self.window = getattr(torch, window)(win_length) + self.window = nn.Parameter(getattr(torch, window)(win_length)) def __call__(self, x): # B x D x T x 2 @@ -22,7 +23,8 @@ class TorchSTFT(): center=True, pad_mode="reflect", # compatible with audio.py normalized=False, - onesided=True) + onesided=True, + return_complex=False) M = o[:, :, :, 0] P = o[:, :, :, 1] return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))