mirror of https://github.com/coqui-ai/TTS.git
fix wavenet running with no input mask
This commit is contained in:
parent
1c1949d348
commit
a8cf1ae6b4
|
@ -91,6 +91,7 @@ class WN(torch.nn.Module):
|
||||||
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
||||||
output = torch.zeros_like(x)
|
output = torch.zeros_like(x)
|
||||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||||
|
x_mask = 1.0 if x_mask is None else x_mask
|
||||||
if g is not None:
|
if g is not None:
|
||||||
g = self.cond_layer(g)
|
g = self.cond_layer(g)
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_layers):
|
||||||
|
@ -163,7 +164,7 @@ class WNBlocks(nn.Module):
|
||||||
weight_norm=weight_norm)
|
weight_norm=weight_norm)
|
||||||
self.wn_blocks.append(layer)
|
self.wn_blocks.append(layer)
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None):
|
def forward(self, x, x_mask=None, g=None):
|
||||||
o = x
|
o = x
|
||||||
for layer in self.wn_blocks:
|
for layer in self.wn_blocks:
|
||||||
o = layer(o, x_mask, g)
|
o = layer(o, x_mask, g)
|
||||||
|
|
Loading…
Reference in New Issue