diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index d49f2725..1107b3c5 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -11,7 +11,8 @@ class TorchSTFT(nn.Module): self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length - self.window = nn.Parameter(getattr(torch, window)(win_length)) + self.window = nn.Parameter(getattr(torch, window)(win_length), + requires_grad=False) def __call__(self, x): # B x D x T x 2