mirror of https://github.com/coqui-ai/TTS.git
formatting
This commit is contained in:
parent
b206162d11
commit
c0a2aa68d3
|
@ -84,30 +84,25 @@ class WN(torch.nn.Module):
|
|||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
|
||||
name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
# setup weight norm
|
||||
if not weight_norm:
|
||||
self.remove_weight_norm()
|
||||
|
||||
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
x_in = self.dropout(x_in)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:,
|
||||
cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
|
||||
n_channels_tensor)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.num_layers - 1:
|
||||
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
|
||||
|
@ -154,6 +149,7 @@ class WNBlocks(nn.Module):
|
|||
c_in_channels=0,
|
||||
dropout_p=0,
|
||||
weight_norm=True):
|
||||
|
||||
super().__init__()
|
||||
self.wn_blocks = nn.ModuleList()
|
||||
for idx in range(num_blocks):
|
||||
|
|
Loading…
Reference in New Issue