diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 740b2a73..fab70594 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,23 +4,24 @@ from torch import nn from torch.nn import functional as F -class TorchSTFT(nn.Module): +class TorchSTFT(nn.Module): # pylint: disable=abstract-method """TODO: Merge this with audio.py""" def __init__(self, n_fft, hop_length, win_length, + pad_wav=False, window='hann_window', sample_rate=None, mel_fmin=0, mel_fmax=None, n_mels=80, use_mel=False): - """ Torch based STFT operation """ super(TorchSTFT, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length + self.pad_wav = pad_wav self.sample_rate = sample_rate self.mel_fmin = mel_fmin self.mel_fmax = mel_fmax @@ -33,8 +34,22 @@ class TorchSTFT(nn.Module): self._build_mel_basis() def __call__(self, x): - padding = int((self.n_fft - self.hop_length) / 2) - x = torch.nn.functional.pad(x, (padding, padding), mode='reflect') + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [B x 1 x T] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode='reflect') # B x D x T x 2 o = torch.stft( x.squeeze(1),