mirror of https://github.com/coqui-ai/TTS.git
fix wavenet running with no input mask
This commit is contained in:
parent
1c1949d348
commit
a8cf1ae6b4
|
@ -91,6 +91,7 @@ class WN(torch.nn.Module):
|
|||
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
x_mask = 1.0 if x_mask is None else x_mask
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
for i in range(self.num_layers):
|
||||
|
@ -163,7 +164,7 @@ class WNBlocks(nn.Module):
|
|||
weight_norm=weight_norm)
|
||||
self.wn_blocks.append(layer)
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
def forward(self, x, x_mask=None, g=None):
|
||||
o = x
|
||||
for layer in self.wn_blocks:
|
||||
o = layer(o, x_mask, g)
|
||||
|
|
Loading…
Reference in New Issue