mirror of https://github.com/coqui-ai/TTS.git
Update wavegrad.py
This should fix the issue https://github.com/mozilla/TTS/issues/581
This commit is contained in:
parent
ac46c3ff4c
commit
f42ca2b73f
|
@ -105,8 +105,8 @@ class Wavegrad(nn.Module):
|
||||||
self.noise_level = self.noise_level.to(y_0)
|
self.noise_level = self.noise_level.to(y_0)
|
||||||
if len(y_0.shape) == 3:
|
if len(y_0.shape) == 3:
|
||||||
y_0 = y_0.squeeze(1)
|
y_0 = y_0.squeeze(1)
|
||||||
s = torch.randint(1, self.num_steps + 1, [y_0.shape[0]])
|
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
|
||||||
l_a, l_b = self.noise_level[s-1], self.noise_level[s]
|
l_a, l_b = self.noise_level[s], self.noise_level[s+1]
|
||||||
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
|
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
|
||||||
noise_scale = noise_scale.unsqueeze(1)
|
noise_scale = noise_scale.unsqueeze(1)
|
||||||
noise = torch.randn_like(y_0)
|
noise = torch.randn_like(y_0)
|
||||||
|
|
Loading…
Reference in New Issue