formatting

erogol 2021-01-11 17:25:39 +01:00
parent b206162d11
commit c0a2aa68d3
1 changed file with 3 additions and 7 deletions

@@ -84,30 +84,25 @@ class WN(torch.nn.Module):
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
                                                        name='weight')
            self.res_skip_layers.append(res_skip_layer)
        # setup weight norm
        if not weight_norm:
            self.remove_weight_norm()

    def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-argument
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])
        if g is not None:
            g = self.cond_layer(g)
        for i in range(self.num_layers):
            x_in = self.in_layers[i](x)
            x_in = self.dropout(x_in)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
-               g_l = g[:,
-                       cond_offset:cond_offset + 2 * self.hidden_channels, :]
+               g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)
            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
                                                   n_channels_tensor)
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
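For context, fused_add_tanh_sigmoid_multiply is defined outside this hunk. In WaveNet-style modules like this WN it is conventionally the WaveGlow gated activation; a minimal sketch assuming that implementation, with n_channels passed as a one-element tensor exactly like the n_channels_tensor above:

import torch

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation unit: the first half of the channels goes
    # through tanh (filter), the second half through sigmoid (gate),
    # and the two halves are multiplied elementwise.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act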
@@ -154,6 +149,7 @@ class WNBlocks(nn.Module):
                 c_in_channels=0,
                 dropout_p=0,
                 weight_norm=True):
        super().__init__()
        self.wn_blocks = nn.ModuleList()
        for idx in range(num_blocks):
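The remove_weight_norm() called from the WN constructor above is likewise outside the diff. A minimal sketch of the usual pattern, assuming the layer lists created in __init__; the guard around cond_layer uses a hypothetical c_in_channels attribute, since the conditioning layer only exists when a conditioning input is configured:

def remove_weight_norm(self):
    # Strip the weight_norm reparameterization applied in __init__,
    # typically done before inference or export.
    if self.c_in_channels != 0:  # assumed attribute; cond_layer exists only with conditioning
        torch.nn.utils.remove_weight_norm(self.cond_layer)
    for layer in self.in_layers:
        torch.nn.utils.remove_weight_norm(layer)
    for layer in self.res_skip_layers:
        torch.nn.utils.remove_weight_norm(layer)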