From 4ffda89c42361af48c15721b2739b36c52f053d6 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 11 Mar 2019 13:03:43 +0100 Subject: [PATCH] reshape input vectors before and after bn layer --- layers/tacotron2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index bbed26ff..a311016e 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -44,9 +44,11 @@ class LinearBN(nn.Module): def forward(self, x): out = self.linear_layer(x) - out = out.transpose(0, 1).transpose(1, 2) + if len(out.shape)==3: + out = out.permute(1, 2, 0) out = self.bn(out) - out = out.transpose(1, 2).transpose(0, 1) + if len(out.shape) == 3: + out = out.permute(2, 0, 1) return out