Fix VITS stochastic duration predictor

This commit is contained in:
Eren Gölge 2021-11-08 09:20:11 +01:00
parent 3a77899775
commit b6b14a76af
1 changed files with 1 additions and 1 deletions

View File

@ -266,7 +266,7 @@ class StochasticDurationPredictor(nn.Module):
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = torch.flip(z, [1])
z = flow(z, x_mask, g=x, reverse=reverse)