from math import sqrt

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F


class Linear(nn.Module):
    """Linear layer with Xavier-uniform weight initialization."""

    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class LinearBN(nn.Module):
    """Linear layer followed by batch normalization over the feature dimension."""

    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        out = self.linear_layer(x)
        if len(out.shape) == 3:
            # BatchNorm1d expects (B, C, T); 3D inputs are assumed time-major (T, B, C).
            out = out.permute(1, 2, 0)
        out = self.bn(out)
        if len(out.shape) == 3:
            # Restore the original (T, B, C) layout.
            out = out.permute(2, 0, 1)
        return out


class Prenet(nn.Module):
    """Stack of fully connected layers with ReLU activations and optional dropout,
    using either plain ``Linear`` ("original") or ``LinearBN`` ("bn") layers."""

    def __init__(self,
                 in_features,
                 prenet_type="original",
                 prenet_dropout=True,
                 out_features=[256, 256],
                 bias=True):
        super(Prenet, self).__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        # Chain layer sizes: each layer's input is the previous layer's output.
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.layers = nn.ModuleList([
                LinearBN(in_size, out_size, bias=bias)
                for (in_size, out_size) in zip(in_features, out_features)
            ])
        elif prenet_type == "original":
            self.layers = nn.ModuleList([
                Linear(in_size, out_size, bias=bias)
                for (in_size, out_size) in zip(in_features, out_features)
            ])

    def forward(self, x):
        for linear in self.layers:
            if self.prenet_dropout:
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            else:
                x = F.relu(linear(x))
        return x
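

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above):
# runs a random batch through both prenet variants to show the expected
# input/output shapes. The tensor sizes and config values here are
# assumptions chosen for demonstration, not values taken from any model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, n_frames, n_mel = 2, 7, 80

    # "original" prenet: plain Linear layers with ReLU + dropout.
    prenet = Prenet(in_features=n_mel, prenet_type="original",
                    prenet_dropout=True, out_features=[256, 256])
    x = torch.rand(batch_size, n_frames, n_mel)
    y = prenet(x)
    print(y.shape)  # torch.Size([2, 7, 256])

    # "bn" prenet: LinearBN layers; the BatchNorm1d path assumes a
    # time-major (T, B, C) layout for 3D inputs, so feed it that way here.
    prenet_bn = Prenet(in_features=n_mel, prenet_type="bn",
                       prenet_dropout=False, out_features=[256, 256])
    x_tbc = torch.rand(n_frames, batch_size, n_mel)
    y_bn = prenet_bn(x_tbc)
    print(y_bn.shape)  # torch.Size([7, 2, 256])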