mirror of https://github.com/coqui-ai/TTS.git
feat(vits): add tau parameter to posterior encoder
This commit is contained in:
parent
6de98ff480
commit
4124b9d663
|
@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module):
|
||||||
)
|
)
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, x, x_lengths, g=None):
|
def forward(self, x, x_lengths, g=None, tau=1.0):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`[B, C, T]`
|
- x: :math:`[B, C, T]`
|
||||||
|
@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module):
|
||||||
x = self.enc(x, x_mask, g=g)
|
x = self.enc(x, x_mask, g=g)
|
||||||
stats = self.proj(x) * x_mask
|
stats = self.proj(x) * x_mask
|
||||||
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
|
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
|
||||||
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
|
z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
|
||||||
return z, mean, log_scale, x_mask
|
return z, mean, log_scale, x_mask
|
||||||
|
|
Loading…
Reference in New Issue