use BN for prenet

This commit is contained in:
Eren Golge 2019-03-07 16:28:50 +01:00
parent cf11b6c23c
commit a144acf466
2 changed files with 28 additions and 5 deletions

View File

@ -63,4 +63,4 @@
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronunciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners"
}
}

View File

@ -25,20 +25,43 @@ class Linear(nn.Module):
return self.linear_layer(x)
class LinearBN(nn.Module):
    """A linear projection followed by 1D batch normalization.

    Expects input of shape (T, B, in_features); the output is batch-normalized
    over the feature dimension and returned with the same (T, B, out_features)
    layout.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool): whether the linear layer learns an additive bias.
        init_gain (str): nonlinearity name passed to
            ``torch.nn.init.calculate_gain`` for Xavier initialization.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        # Xavier-uniform init scaled for the configured nonlinearity.
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        projected = self.linear_layer(x)
        # BatchNorm1d wants (B, C, T): move features to the channel axis,
        # normalize, then restore the original (T, B, C) layout.
        normalized = self.bn(projected.permute(1, 2, 0))
        return normalized.permute(2, 0, 1)
class Prenet(nn.Module):
    """Tacotron-style prenet: a stack of batch-normalized linear layers,
    each followed by a ReLU.

    Note: with the switch from plain ``Linear`` to ``LinearBN`` the
    always-on dropout was removed from the forward pass.

    Args:
        in_features (int): dimensionality of the input features.
        out_features (list[int] | None): output size of each layer; its
            length defines the number of layers. Defaults to [256, 256].
    """

    def __init__(self, in_features, out_features=None):
        super(Prenet, self).__init__()
        # None-sentinel instead of a shared mutable default argument;
        # behavior for callers passing their own list is unchanged.
        if out_features is None:
            out_features = [256, 256]
        # Each layer consumes the previous layer's output size.
        in_sizes = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList([
            LinearBN(in_size, out_size, bias=False)
            for (in_size, out_size) in zip(in_sizes, out_features)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = F.relu(layer(x))
        return x