feat(vits): add tau parameter to posterior encoder

This commit is contained in:
Enno Hermann 2024-06-25 22:28:41 +02:00
parent 6de98ff480
commit 4124b9d663
1 changed files with 2 additions and 2 deletions

View File

@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module):
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
def forward(self, x, x_lengths, g=None, tau=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module):
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask