bug fix for prenet setup

This commit is contained in:
Eren Golge 2019-04-08 18:28:19 +02:00
parent 303dc9d62e
commit e3647fa7b3
1 changed files with 4 additions and 4 deletions

View File

@ -57,12 +57,12 @@ class Prenet(nn.Module):
super(Prenet, self).__init__()
self.prenet_type = prenet_type
in_features = [in_features] + out_features[:-1]
if prenet_type == "original":
if prenet_type == "bn":
self.layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_features, out_features)
])
elif prenet_type == "bn":
elif prenet_type == "original":
self.layers = nn.ModuleList(
[Linear(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_features, out_features)
@ -71,9 +71,9 @@ class Prenet(nn.Module):
def forward(self, x):
for linear in self.layers:
if self.prenet_type == "original":
x = F.relu(linear(x))
elif self.prenet_type == "bn":
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
elif self.prenet_type == "bn":
x = F.relu(linear(x))
return x