mirror of https://github.com/coqui-ai/TTS.git
bug fix for prenet setup
This commit is contained in:
parent
303dc9d62e
commit
e3647fa7b3
|
@ -57,12 +57,12 @@ class Prenet(nn.Module):
|
|||
super(Prenet, self).__init__()
|
||||
self.prenet_type = prenet_type
|
||||
in_features = [in_features] + out_features[:-1]
|
||||
if prenet_type == "original":
|
||||
if prenet_type == "bn":
|
||||
self.layers = nn.ModuleList([
|
||||
LinearBN(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
])
|
||||
elif prenet_type == "bn":
|
||||
elif prenet_type == "original":
|
||||
self.layers = nn.ModuleList(
|
||||
[Linear(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
|
@ -71,9 +71,9 @@ class Prenet(nn.Module):
|
|||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
if self.prenet_type == "original":
|
||||
x = F.relu(linear(x))
|
||||
elif self.prenet_type == "bn":
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
|
||||
elif self.prenet_type == "bn":
|
||||
x = F.relu(linear(x))
|
||||
return x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue