From 79ed5debcd6c5b5adfa86183b35a6a05765a20e8 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 17 Nov 2020 14:15:14 +0100 Subject: [PATCH] fix wavegrad test --- tests/test_wavegrad_train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_wavegrad_train.py b/tests/test_wavegrad_train.py index 604be401..700e94d1 100644 --- a/tests/test_wavegrad_train.py +++ b/tests/test_wavegrad_train.py @@ -1,5 +1,6 @@ import unittest +import numpy as np import torch from torch import optim from TTS.vocoder.models.wavegrad import Wavegrad @@ -33,7 +34,8 @@ class WavegradTrainTest(unittest.TestCase): [1, 2, 4, 8]]) model.train() model.to(device) - model.compute_noise_level(1000, 1e-6, 1e-2) + betas = np.linspace(1e-6, 1e-2, 1000) + model.compute_noise_level(betas) model_ref.load_state_dict(model.state_dict()) model_ref.to(device) count = 0