mirror of https://github.com/coqui-ai/TTS.git

use BN for prenet

parent: cf11b6c23c
commit: a144acf466
@@ -63,4 +63,4 @@
     "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners"
 }
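For context on the settings above: when "use_phonemes" is enabled, input text is converted to phonemes through the phonemizer package linked in the comment, with "phoneme_language" selecting the backend language. A minimal standalone sketch of that conversion (not code from this repo; it assumes `pip install phonemizer` plus an installed espeak backend):

# Illustrative only: how "phoneme_language" maps onto the phonemizer API.
from phonemizer import phonemize

phonemes = phonemize("Hello world.", language='en-us', backend='espeak', strip=True)
print(phonemes)  # IPA-like output, e.g. roughly "həloʊ wɜːld"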
@@ -25,20 +25,43 @@ class Linear(nn.Module):
         return self.linear_layer(x)
 
 
+class LinearBN(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(LinearBN, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self.bn = nn.BatchNorm1d(out_features)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        out = self.linear_layer(x)
+        out = out.transpose(0, 1).transpose(1, 2)
+        out = self.bn(out)
+        out = out.transpose(1, 2).transpose(0, 1)
+        return out
+
+
 class Prenet(nn.Module):
     def __init__(self, in_features, out_features=[256, 256]):
         super(Prenet, self).__init__()
         in_features = [in_features] + out_features[:-1]
         self.layers = nn.ModuleList([
-            Linear(in_size, out_size, bias=False)
+            LinearBN(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_features, out_features)
         ])
 
     def forward(self, x):
         for linear in self.layers:
-            # Prenet uses dropout also at inference time. Otherwise,
-            # it degrades the inference time attention.
-            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
+            x = F.relu(linear(x))
         return x
 
 
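A note on the new layer, for readers of this diff: nn.BatchNorm1d normalizes over the channel dimension of a (batch, channels, time) tensor, while the transposes in LinearBN.forward imply the prenet input is time-major, i.e. (time, batch, channels); the data is permuted to (batch, channels, time) for the BN step and permuted back afterwards. The commit also drops the old prenet's always-on dropout, leaving batch normalization as the regularizer. Below is a standalone shape-check sketch (not part of the commit; it restates the two classes from the hunk above, and the tensor sizes 30/16/80 are illustrative):

# Illustrative, standalone shape check for the BN prenet introduced above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearBN(nn.Module):
    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        out = self.linear_layer(x)                  # (T, B, C)
        out = out.transpose(0, 1).transpose(1, 2)   # -> (B, C, T) for BatchNorm1d
        out = self.bn(out)
        return out.transpose(1, 2).transpose(0, 1)  # -> back to (T, B, C)

class Prenet(nn.Module):
    def __init__(self, in_features, out_features=[256, 256]):
        super(Prenet, self).__init__()
        in_features = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList([
            LinearBN(in_size, out_size, bias=False)
            for (in_size, out_size) in zip(in_features, out_features)])

    def forward(self, x):
        for linear in self.layers:
            x = F.relu(linear(x))  # no dropout anymore: BN regularizes instead
        return x

prenet = Prenet(in_features=80, out_features=[256, 256])
x = torch.randn(30, 16, 80)  # (time, batch, channels), as the transposes imply
print(prenet(x).shape)       # torch.Size([30, 16, 256])

The bias=False passed to LinearBN fits this design: the BN layer's learned shift makes a separate linear bias redundant.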