mirror of https://github.com/coqui-ai/TTS.git
fix wavegrad test
This commit is contained in:
parent
a2a142dc39
commit
79ed5debcd
|
@ -1,5 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||||
|
@ -33,7 +34,8 @@ class WavegradTrainTest(unittest.TestCase):
|
||||||
[1, 2, 4, 8]])
|
[1, 2, 4, 8]])
|
||||||
model.train()
|
model.train()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.compute_noise_level(1000, 1e-6, 1e-2)
|
betas = np.linspace(1e-6, 1e-2, 1000)
|
||||||
|
model.compute_noise_level(betas)
|
||||||
model_ref.load_state_dict(model.state_dict())
|
model_ref.load_state_dict(model.state_dict())
|
||||||
model_ref.to(device)
|
model_ref.to(device)
|
||||||
count = 0
|
count = 0
|
||||||
|
|
Loading…
Reference in New Issue