# Mirror of https://github.com/coqui-ai/TTS.git
import torch
from torch import nn
from torch.nn import functional as F


class Linear(nn.Module):
    """Linear layer with Xavier-uniform weight initialization."""

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        # Xavier (Glorot) uniform init, scaled by the gain of the given
        # nonlinearity (e.g. 'linear', 'relu', 'tanh').
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)

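# A minimal usage sketch (assumed shapes, not part of the original file):
#
#     layer = Linear(80, 256, init_gain='relu')
#     x = torch.rand(32, 80)   # (batch, in_features)
#     y = layer(x)             # -> (32, 256)
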
class LinearBN(nn.Module):
    """Linear layer followed by 1D batch normalization."""

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        out = self.linear_layer(x)
        if len(out.shape) == 3:
            # (T, B, C) -> (B, C, T): BatchNorm1d normalizes over dim 1.
            out = out.permute(1, 2, 0)
        out = self.bn(out)
        if len(out.shape) == 3:
            # Restore the original (T, B, C) layout.
            out = out.permute(2, 0, 1)
        return out

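# A hedged usage sketch (the permutes in forward() imply sequence-first
# (T, B, C) inputs for the 3D case; not part of the original file):
#
#     layer = LinearBN(80, 256)
#     x = torch.rand(50, 32, 80)   # (time, batch, in_features)
#     y = layer(x)                 # -> (50, 32, 256)
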
class Prenet(nn.Module):
    """Tacotron-style prenet: a stack of fully connected + ReLU layers,
    either plain with dropout ("original") or batch-normalized ("bn")."""

    def __init__(self,
                 in_features,
                 prenet_type="original",
                 prenet_dropout=True,
                 out_features=(256, 256),
                 bias=True):
        super(Prenet, self).__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        # Each layer's input size is the previous layer's output size.
        # The default is a tuple to avoid the shared-mutable-default trap.
        out_features = list(out_features)
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.layers = nn.ModuleList([
                LinearBN(in_size, out_size, bias=bias)
                for (in_size, out_size) in zip(in_features, out_features)
            ])
        elif prenet_type == "original":
            self.layers = nn.ModuleList([
                Linear(in_size, out_size, bias=bias)
                for (in_size, out_size) in zip(in_features, out_features)
            ])
        else:
            raise ValueError(f"Unknown prenet_type: {prenet_type}")

    def forward(self, x):
        for linear in self.layers:
            if self.prenet_dropout:
                # ReLU followed by dropout (p=0.5), active only in training.
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            else:
                x = F.relu(linear(x))
        return x
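

# A minimal smoke test (assumed shapes, not part of the original file); runs
# a random batch through both prenet variants.
if __name__ == "__main__":
    dummy = torch.rand(8, 80)  # (batch, in_features)
    for ptype in ("original", "bn"):
        prenet = Prenet(80, prenet_type=ptype)
        print(ptype, prenet(dummy).shape)  # -> torch.Size([8, 256])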