From e3647fa7b33e60597aad881e9ed8754f5e405a32 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 8 Apr 2019 18:28:19 +0200
Subject: [PATCH] bug fix for prenet setup

---
 layers/tacotron2.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 4e41f366..e1aa02dd 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -57,12 +57,12 @@ class Prenet(nn.Module):
         super(Prenet, self).__init__()
         self.prenet_type = prenet_type
         in_features = [in_features] + out_features[:-1]
-        if prenet_type == "original":
+        if prenet_type == "bn":
             self.layers = nn.ModuleList([
                 LinearBN(in_size, out_size, bias=False)
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
-        elif prenet_type == "bn":
+        elif prenet_type == "original":
             self.layers = nn.ModuleList(
                 [Linear(in_size, out_size, bias=False)
                  for (in_size, out_size) in zip(in_features, out_features)
@@ -71,9 +71,9 @@ class Prenet(nn.Module):

     def forward(self, x):
         for linear in self.layers:
             if self.prenet_type == "original":
-                x = F.relu(linear(x))
-            elif self.prenet_type == "bn":
                 x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+            elif self.prenet_type == "bn":
+                x = F.relu(linear(x))
         return x
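
Note: before this patch the two prenet variants were wired up backwards: prenet_type == "original" built LinearBN layers but applied only ReLU, while "bn" built plain Linear layers and applied ReLU + dropout. After the fix, "original" is the Tacotron 2 prenet (plain Linear, ReLU, dropout 0.5) and "bn" is the batch-normalised variant (LinearBN, ReLU, no dropout). Below is a minimal sketch of the corrected behaviour; it uses plain nn.Linear in place of the repo's Linear/LinearBN wrappers, which are assumed helpers from this file and not reproduced here.

    # Sketch of the Prenet branch logic after the fix (simplified, hypothetical class name).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PrenetSketch(nn.Module):
        def __init__(self, in_features, out_features=(256, 256), prenet_type="original"):
            super().__init__()
            self.prenet_type = prenet_type
            in_sizes = [in_features] + list(out_features[:-1])
            # Both variants use plain Linear here; in the repo, "bn" would use LinearBN
            # (Linear followed by batch norm) and "original" the plain Linear wrapper.
            self.layers = nn.ModuleList(
                nn.Linear(i, o, bias=False) for i, o in zip(in_sizes, out_features)
            )

        def forward(self, x):
            for layer in self.layers:
                if self.prenet_type == "original":
                    # Tacotron 2 prenet: ReLU followed by dropout(0.5); the patch keeps
                    # dropout tied to train mode via training=self.training.
                    x = F.dropout(F.relu(layer(x)), p=0.5, training=self.training)
                else:  # "bn": ReLU only, batch norm (not shown) regularises instead
                    x = F.relu(layer(x))
            return x

    # Quick check: a (batch, time, channels) input passes through both variants.
    x = torch.randn(2, 7, 80)
    print(PrenetSketch(80, prenet_type="original")(x).shape)  # torch.Size([2, 7, 256])
    print(PrenetSketch(80, prenet_type="bn")(x).shape)        # torch.Size([2, 7, 256])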