diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index dbfa9f69..97eee879 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -91,6 +91,7 @@ class WN(torch.nn.Module): def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument output = torch.zeros_like(x) n_channels_tensor = torch.IntTensor([self.hidden_channels]) + x_mask = 1.0 if x_mask is None else x_mask if g is not None: g = self.cond_layer(g) for i in range(self.num_layers): @@ -163,7 +164,7 @@ class WNBlocks(nn.Module): weight_norm=weight_norm) self.wn_blocks.append(layer) - def forward(self, x, x_mask, g=None): + def forward(self, x, x_mask=None, g=None): o = x for layer in self.wn_blocks: o = layer(o, x_mask, g)