mirror of https://github.com/coqui-ai/TTS.git
use BN for prenet
parent cf11b6c23c · commit a144acf466
@@ -63,4 +63,4 @@
 "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronunciation.
 "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
 "text_cleaner": "phoneme_cleaners"
-}
+}
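These keys switch the text frontend from raw characters to phonemizer output. A minimal sketch of consuming them, assuming the file is plain JSON named config.json (the filename and loader are illustrative, not part of this commit; the real config allows // comments, which strict json.load rejects):

import json

# Illustrative only: load the training config and read the
# phoneme-related keys shown in the hunk above.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

if config["use_phonemes"]:
    # Characters are converted to phonemes for this language
    # before being fed to the model.
    print(config["phoneme_language"], config["text_cleaner"])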
@@ -25,20 +25,43 @@ class Linear(nn.Module):
         return self.linear_layer(x)
 
 
+class LinearBN(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(LinearBN, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self.bn = nn.BatchNorm1d(out_features)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        out = self.linear_layer(x)
+        out = out.transpose(0, 1).transpose(1, 2)
+        out = self.bn(out)
+        out = out.transpose(1, 2).transpose(0, 1)
+        return out
+
+
 class Prenet(nn.Module):
     def __init__(self, in_features, out_features=[256, 256]):
         super(Prenet, self).__init__()
         in_features = [in_features] + out_features[:-1]
         self.layers = nn.ModuleList([
-            Linear(in_size, out_size, bias=False)
+            LinearBN(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_features, out_features)
         ])
 
     def forward(self, x):
         for linear in self.layers:
-            # Prenet uses dropout also at inference time. Otherwise,
-            # it degrades the inference time attention.
-            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
+            x = F.relu(linear(x))
         return x
 
 
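The pattern in LinearBN.forward: nn.BatchNorm1d expects a (batch, channels, time) tensor and normalizes over channels, while the prenet works on time-major (time, batch, features) activations, hence the transpose round-trip. The commit also drops the always-on dropout the old prenet used to keep inference-time attention stable, relying on BN instead. A minimal self-contained sketch of the new module; the shapes (50 steps, batch 8, 80-dim frames) are assumptions for illustration, not from this commit:

import torch
import torch.nn as nn

class LinearBN(nn.Module):
    # Same module as in the diff above, condensed for a standalone check.
    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        out = self.linear_layer(x)                 # (T, B, C)
        out = out.transpose(0, 1).transpose(1, 2)  # -> (B, C, T) for BatchNorm1d
        out = self.bn(out)
        out = out.transpose(1, 2).transpose(0, 1)  # back to (T, B, C)
        return out

x = torch.randn(50, 8, 80)             # assumed (time, batch, features) input
layer = LinearBN(80, 256, bias=False)
print(layer(x).shape)                  # torch.Size([50, 8, 256])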