diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 4e41f366..e1aa02dd 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -57,12 +57,12 @@ class Prenet(nn.Module): super(Prenet, self).__init__() self.prenet_type = prenet_type in_features = [in_features] + out_features[:-1] - if prenet_type == "original": + if prenet_type == "bn": self.layers = nn.ModuleList([ LinearBN(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_features, out_features) ]) - elif prenet_type == "bn": + elif prenet_type == "original": self.layers = nn.ModuleList( [Linear(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_features, out_features) @@ -71,9 +71,9 @@ class Prenet(nn.Module): def forward(self, x): for linear in self.layers: if self.prenet_type == "original": - x = F.relu(linear(x)) - elif self.prenet_type == "bn": x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training) + elif self.prenet_type == "bn": + x = F.relu(linear(x)) return x