fix wavenet running with no input mask

This commit is contained in:
Eren Gölge 2021-03-02 12:18:35 +01:00 committed by Eren Gölge
parent 1c1949d348
commit a8cf1ae6b4
1 changed files with 2 additions and 1 deletions

View File

@ -91,6 +91,7 @@ class WN(torch.nn.Module):
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
x_mask = 1.0 if x_mask is None else x_mask
if g is not None:
g = self.cond_layer(g)
for i in range(self.num_layers):
@ -163,7 +164,7 @@ class WNBlocks(nn.Module):
weight_norm=weight_norm)
self.wn_blocks.append(layer)
def forward(self, x, x_mask, g=None):
def forward(self, x, x_mask=None, g=None):
o = x
for layer in self.wn_blocks:
o = layer(o, x_mask, g)