diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 34c2f9b7..34165cc7 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -5,12 +5,14 @@ from torch.nn import functional as F class TorchSTFT(nn.Module): # pylint: disable=abstract-method - def __init__(self, n_fft, hop_length, win_length, window='hann_window'): + # TODO: move this to audio.py with a transparent torch API. + def __init__(self, n_fft, hop_length, win_length, pad_mode='reflect', window='hann_window'): """ Torch based STFT operation """ super(TorchSTFT, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length + self.pad_mode = pad_mode self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) @@ -22,7 +24,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method self.win_length, self.window, center=True, - pad_mode="reflect", # compatible with audio.py + pad_mode=self.pad_mode, # needs to be compatible with audio.py normalized=False, onesided=True, return_complex=False)