From c0a2aa68d3f0d4e5ce039de3902a40497d13e8e5 Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 11 Jan 2021 17:25:39 +0100
Subject: [PATCH] formatting

---
 TTS/tts/layers/generic/wavenet.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py
index d0a1f3e9..9906aa4a 100644
--- a/TTS/tts/layers/generic/wavenet.py
+++ b/TTS/tts/layers/generic/wavenet.py
@@ -84,30 +84,25 @@ class WN(torch.nn.Module):
             res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
                                                         name='weight')
             self.res_skip_layers.append(res_skip_layer)
-
+        # setup weight norm
         if not weight_norm:
             self.remove_weight_norm()
 
     def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-argument
         output = torch.zeros_like(x)
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
         if g is not None:
             g = self.cond_layer(g)
-
         for i in range(self.num_layers):
             x_in = self.in_layers[i](x)
             x_in = self.dropout(x_in)
             if g is not None:
                 cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:,
-                        cond_offset:cond_offset + 2 * self.hidden_channels, :]
+                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
             else:
                 g_l = torch.zeros_like(x_in)
-
             acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
                                                    n_channels_tensor)
-
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.num_layers - 1:
                 x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
@@ -154,6 +149,7 @@ class WNBlocks(nn.Module):
                  c_in_channels=0,
                  dropout_p=0,
                  weight_norm=True):
+        super().__init__()
        self.wn_blocks = nn.ModuleList()
        for idx in range(num_blocks):
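
Note (not part of the patch): the reflowed `g_l` line above slices out layer i's share of the
conditioning tensor and hands it to `fused_add_tanh_sigmoid_multiply`, the usual WaveNet gated
activation. Below is a minimal sketch of that step, assuming the standard tanh/sigmoid split over
the first and second halves of the channels; the `_sketch` helper and the toy shapes are
illustrative stand-ins, not the repo's actual implementation.

    import torch

    def fused_add_tanh_sigmoid_multiply_sketch(x_in, g_l, n_channels_tensor):
        # Add local conditioning, then gate: tanh("filter" half) * sigmoid("gate" half).
        n = int(n_channels_tensor[0])
        in_act = x_in + g_l
        t_act = torch.tanh(in_act[:, :n, :])
        s_act = torch.sigmoid(in_act[:, n:, :])
        return t_act * s_act

    # Toy usage mirroring one WN layer: slice layer i's conditioning out of g.
    hidden_channels, num_layers, T = 4, 3, 10
    x_in = torch.randn(1, 2 * hidden_channels, T)            # stand-in for in_layers[i] output
    g = torch.randn(1, 2 * hidden_channels * num_layers, T)  # stand-in for cond_layer output
    i = 1
    cond_offset = i * 2 * hidden_channels
    g_l = g[:, cond_offset:cond_offset + 2 * hidden_channels, :]
    acts = fused_add_tanh_sigmoid_multiply_sketch(
        x_in, g_l, torch.IntTensor([hidden_channels]))

The second hunk's `super().__init__()` is more than formatting: WNBlocks subclasses nn.Module, so
its constructor must call the parent constructor before registering the nn.ModuleList.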