Use complex-valued torch.stft output at every call site.

torch.stft requires an explicit return_complex argument for real inputs,
and return_complex=False is deprecated. Each call therefore requests a
complex result (return_complex=True) and converts it with
torch.view_as_real, so the downstream magnitude code still sees a real
tensor whose trailing dimension holds the (real, imaginary) pair.
wav_to_mel receives the same conversion as the other call sites.

Note: the index blob ids below were carried over and not recomputed.

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 6770a98e..70ca4dbf 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -131,9 +131,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
         pad_mode="reflect",
         normalized=False,
         onesided=True,
-        return_complex=False,
+        return_complex=True,
     )
-
+    spec = torch.view_as_real(spec)
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
     return spec
 
@@ -199,7 +199,7 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
         pad_mode="reflect",
         normalized=False,
         onesided=True,
-        return_complex=False,
+        return_complex=True,
     )
-
+    spec = torch.view_as_real(spec)
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py
index fd40ebb0..6eae37dc 100644
--- a/TTS/utils/audio/torch_transforms.py
+++ b/TTS/utils/audio/torch_transforms.py
@@ -129,8 +129,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
             pad_mode="reflect",  # compatible with audio.py
             normalized=self.normalized,
             onesided=True,
-            return_complex=False,
+            return_complex=True,
         )
+        o = torch.view_as_real(o)
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
         S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py
index 2dcbf214..3b421942 100644
--- a/TTS/vc/modules/freevc/mel_processing.py
+++ b/TTS/vc/modules/freevc/mel_processing.py
@@ -64,9 +64,9 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
         pad_mode="reflect",
         normalized=False,
         onesided=True,
-        return_complex=False,
+        return_complex=True,
     )
-
+    spec = torch.view_as_real(spec)
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
     return spec
 
@@ -114,9 +114,9 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
         pad_mode="reflect",
         normalized=False,
         onesided=True,
-        return_complex=False,
+        return_complex=True,
     )
-
+    spec = torch.view_as_real(spec)
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 
     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)