diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index 50ed1024..ab2ca566 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module): ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward(self, x, x_lengths, g=None, tau=1.0): """ Shapes: - x: :math:`[B, C, T]` @@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module): x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask mean, log_scale = torch.split(stats, self.out_channels, dim=1) - z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask return z, mean, log_scale, x_mask