reshape input vectors before and after bn layer

This commit is contained in:
Eren Golge 2019-03-11 13:03:43 +01:00
parent a144acf466
commit 4ffda89c42
1 changed files with 4 additions and 2 deletions

View File

@ -44,9 +44,11 @@ class LinearBN(nn.Module):
def forward(self, x):
out = self.linear_layer(x)
out = out.transpose(0, 1).transpose(1, 2)
if len(out.shape)==3:
out = out.permute(1, 2, 0)
out = self.bn(out)
out = out.transpose(1, 2).transpose(0, 1)
if len(out.shape) == 3:
out = out.permute(2, 0, 1)
return out