mirror of https://github.com/coqui-ai/TTS.git
feat(vits): add tau parameter to posterior encoder
This commit is contained in:
parent
6de98ff480
commit
4124b9d663
|
@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module):
|
|||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
def forward(self, x, x_lengths, g=None, tau=1.0):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
|
@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module):
|
|||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
|
||||
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
|
||||
z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
|
||||
return z, mean, log_scale, x_mask
|
||||
|
|
Loading…
Reference in New Issue