From 4124b9d663b4eea5e7034e96351fe5d4180cfb89 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 25 Jun 2024 22:28:41 +0200 Subject: [PATCH] feat(vits): add tau parameter to posterior encoder --- TTS/tts/layers/vits/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index 50ed1024..ab2ca566 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module): ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward(self, x, x_lengths, g=None, tau=1.0): """ Shapes: - x: :math:`[B, C, T]` @@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module): x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask mean, log_scale = torch.split(stats, self.out_channels, dim=1) - z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask return z, mean, log_scale, x_mask