diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index aeb45c7b..613ad19d 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -67,9 +67,14 @@ class WN(torch.nn.Module): for i in range(num_layers): dilation = dilation_rate**i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d( - hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding - ) + if i == 0: + in_layer = torch.nn.Conv1d( + in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + else: + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index ff1b99e8..3b745018 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -197,7 +197,7 @@ class CouplingBlock(nn.Module): end.bias.data.zero_() self.end = end # coupling layers - self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) + self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument """