feat(vits): add tau parameter to posterior encoder

This commit is contained in:
Enno Hermann 2024-06-25 22:28:41 +02:00
parent 6de98ff480
commit 4124b9d663
1 changed files with 2 additions and 2 deletions

View File

@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module):
) )
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None, tau=1.0):
""" """
Shapes: Shapes:
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module):
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1) mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask return z, mean, log_scale, x_mask