mirror of https://github.com/coqui-ai/TTS.git
update tests
This commit is contained in:
parent
7c671cae5a
commit
042cde15d6
|
@ -20,7 +20,7 @@ ap = AudioProcessor(**C.audio)
|
|||
|
||||
|
||||
def test_torch_stft():
|
||||
torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
|
||||
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
||||
# librosa stft
|
||||
wav = ap.load_wav(WAV_FILE)
|
||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
||||
|
@ -32,7 +32,7 @@ def test_torch_stft():
|
|||
|
||||
|
||||
def test_stft_loss():
|
||||
stft_loss = STFTLoss(ap.n_fft, ap.hop_length, ap.win_length)
|
||||
stft_loss = STFTLoss(ap.fft_size, ap.hop_length, ap.win_length)
|
||||
wav = ap.load_wav(WAV_FILE)
|
||||
wav = torch.from_numpy(wav[None, :]).float()
|
||||
loss_m, loss_sc = stft_loss(wav, wav)
|
||||
|
@ -43,7 +43,7 @@ def test_stft_loss():
|
|||
|
||||
|
||||
def test_multiscale_stft_loss():
|
||||
stft_loss = MultiScaleSTFTLoss([ap.n_fft//2, ap.n_fft, ap.n_fft*2],
|
||||
stft_loss = MultiScaleSTFTLoss([ap.fft_size//2, ap.fft_size, ap.fft_size*2],
|
||||
[ap.hop_length // 2, ap.hop_length, ap.hop_length * 2],
|
||||
[ap.win_length // 2, ap.win_length, ap.win_length * 2])
|
||||
wav = ap.load_wav(WAV_FILE)
|
||||
|
|
Loading…
Reference in New Issue