mirror of https://github.com/coqui-ai/TTS.git
small fixes
This commit is contained in:
parent
5901a00576
commit
eb555855e4
@@ -15,7 +15,6 @@ class PositionalEncoding(nn.Module):
        dropout (float): dropout parameter
        dim (int): embedding size
    """
-
    def __init__(self, dim, dropout=0.0, max_len=5000):
        super().__init__()
        if dim % 2 != 0:
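Background on the `dim % 2 != 0` guard visible in this hunk: sinusoidal position encodings fill even channels with sines and odd channels with cosines, so the embedding size must be even. Below is a minimal sketch of that standard "Attention Is All You Need" formulation for illustration only; it is not the repo's exact implementation, and the names and sizes are made up.

```python
import math
import torch

def sinusoidal_positions(max_len: int, dim: int) -> torch.Tensor:
    """Standard sinusoidal position table of shape [max_len, dim]."""
    if dim % 2 != 0:
        raise ValueError(f"Cannot use sinusoidal encoding with odd dim (got dim={dim})")
    position = torch.arange(max_len).unsqueeze(1)                              # [max_len, 1]
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))  # [dim // 2]
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even channels
    pe[:, 1::2] = torch.cos(position * div_term)   # odd channels
    return pe

table = sinusoidal_positions(max_len=5000, dim=192)   # dim must be even
```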
@@ -117,33 +116,36 @@ class Encoder(nn.Module):
        """
        super().__init__()
        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
+        self.in_channels = in_hidden_channels
+        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init encoder
        if encoder_type.lower() == "transformer":
            # optional convolutional prenet
-            self.pre = ConvLayerNorm(hidden_channels,
-                                     hidden_channels,
-                                     hidden_channels,
+            self.pre = ConvLayerNorm(self.in_channels,
+                                     self.hidden_channels,
+                                     self.hidden_channels,
                                     kernel_size=5,
                                     num_layers=3,
                                     dropout_p=0.5)
            # text encoder
-            self.encoder = Transformer(hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
+            self.encoder = Transformer(self.hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
        elif encoder_type.lower() == 'residual_conv_bn':
            self.pre = nn.Sequential(
-                nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
-            self.encoder = ResidualConvBNBlock(hidden_channels,
+                nn.Conv1d(self.in_channels, self.hidden_channels, 1),
+                nn.ReLU())
+            self.encoder = ResidualConvBNBlock(self.hidden_channels,
                                               **encoder_params)
        else:
            raise NotImplementedError(' [!] encoder type not implemented.')

        # final projection layers
-        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
-        self.post_bn = nn.BatchNorm1d(hidden_channels)
-        self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
+                                   1)
+        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
+        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)

    def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
        # TODO: implement multi-speaker
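The Encoder changes in this hunk all serve one purpose: the module previously assumed its input width equalled `hidden_channels`, whereas now `in_channels` (taken from `in_hidden_channels`) may differ and the prenet adapts it before the encoder body and projection layers run at the hidden width. A rough, self-contained sketch of that data flow for the 'residual_conv_bn' branch follows; the channel sizes are illustrative only and the `ResidualConvBNBlock` body is omitted.

```python
import torch
import torch.nn as nn

# Illustrative sizes only: the point of the change is that the character
# embedding width (in_channels) no longer has to equal the encoder width.
in_channels, hidden_channels, out_channels = 256, 128, 80

# 1x1-conv prenet maps in_channels -> hidden_channels, mirroring the
# 'residual_conv_bn' branch in the diff (ResidualConvBNBlock itself omitted).
pre = nn.Sequential(
    nn.Conv1d(in_channels, hidden_channels, 1),
    nn.ReLU())

# final projection layers, as in the diff
post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
post_bn = nn.BatchNorm1d(hidden_channels)
post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)

x = torch.randn(2, in_channels, 37)        # [B, C_in, T] character embeddings
o = pre(x)                                 # [B, hidden, T]
o = post_conv2(post_bn(post_conv(o)))      # [B, out, T]
print(o.shape)                             # torch.Size([2, 80, 37])
```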
@@ -126,6 +126,7 @@ class GlowTts(nn.Module):
            x_lenghts: B
            y: B x C x T
            y_lengths: B
+            g: B x C or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings
@@ -133,7 +134,7 @@ class GlowTts(nn.Module):
            if self.external_speaker_embedding_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
-                g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h]
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1]

            # embedding pass
            o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
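The comment fix in the last hunk is about shapes: after the embedding lookup the speaker vector is [b, h], and the trailing `unsqueeze(-1)` turns it into [b, h, 1] so it can broadcast against [B, C, T] feature maps. A small shape check with made-up sizes, covering both the internal-embedding and external-embedding paths shown in the diff:

```python
import torch
import torch.nn.functional as F

num_speakers, embed_dim, batch = 10, 64, 2
emb_g = torch.nn.Embedding(num_speakers, embed_dim)

speaker_ids = torch.tensor([3, 7])                  # g: B (integer speaker ids)
g = F.normalize(emb_g(speaker_ids)).unsqueeze(-1)   # [b, h] -> [b, h, 1]
print(g.shape)                                      # torch.Size([2, 64, 1])

external = torch.randn(batch, embed_dim)            # g: B x C (external speaker vectors)
g_ext = F.normalize(external).unsqueeze(-1)         # also [b, h, 1]
```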