small fixes

erogol 2020-12-30 14:20:08 +01:00
parent 5901a00576
commit eb555855e4
2 changed files with 15 additions and 12 deletions

@@ -15,7 +15,6 @@ class PositionalEncoding(nn.Module):
         dropout (float): dropout parameter
         dim (int): embedding size
     """
     def __init__(self, dim, dropout=0.0, max_len=5000):
         super().__init__()
         if dim % 2 != 0:
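
As context for this hunk, here is a minimal sketch of a standard sinusoidal positional-encoding module with the constructor signature shown above (dim, dropout=0.0, max_len=5000); the body is an illustration of the usual Transformer formulation, not the repository's exact code:

import math
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (a sketch, not the repo's code)."""
    def __init__(self, dim, dropout=0.0, max_len=5000):
        super().__init__()
        if dim % 2 != 0:
            # sine/cosine pairs need an even embedding size
            raise ValueError(f"dim must be even, got {dim}")
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, dim, 2).float()
                             * -(math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        self.register_buffer("pe", pe.unsqueeze(0))   # [1, max_len, dim]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, T, dim]; add the first T positions and apply dropout
        return self.dropout(x + self.pe[:, :x.size(1)])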
@@ -117,33 +116,36 @@ class Encoder(nn.Module):
         """
         super().__init__()
         self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
+        self.in_channels = in_hidden_channels
+        self.hidden_channels = in_hidden_channels
         self.encoder_type = encoder_type
         self.c_in_channels = c_in_channels
         # init encoder
         if encoder_type.lower() == "transformer":
             # optional convolutional prenet
-            self.pre = ConvLayerNorm(hidden_channels,
-                                     hidden_channels,
-                                     hidden_channels,
+            self.pre = ConvLayerNorm(self.in_channels,
+                                     self.hidden_channels,
+                                     self.hidden_channels,
                                      kernel_size=5,
                                      num_layers=3,
                                      dropout_p=0.5)
             # text encoder
-            self.encoder = Transformer(hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
+            self.encoder = Transformer(self.hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
         elif encoder_type.lower() == 'residual_conv_bn':
             self.pre = nn.Sequential(
-                nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
-            self.encoder = ResidualConvBNBlock(hidden_channels,
+                nn.Conv1d(self.in_channels, self.hidden_channels, 1),
+                nn.ReLU())
+            self.encoder = ResidualConvBNBlock(self.hidden_channels,
                                                **encoder_params)
         else:
             raise NotImplementedError(' [!] encoder type not implemented.')
         # final projection layers
-        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
-        self.post_bn = nn.BatchNorm1d(hidden_channels)
-        self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
+                                   1)
+        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
+        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
 
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         # TODO: implement multi-speaker
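
The substance of the hunk above: the single hidden_channels argument becomes in_hidden_channels and is stored under two names, self.in_channels (the width the prenet reads) and self.hidden_channels (the width everything after the prenet uses), so the prenet's first layer consumes the actual input width. A toy sketch of that bookkeeping, with plain Conv1d layers standing in for ConvLayerNorm/ResidualConvBNBlock and the encoder body omitted:

import torch
from torch import nn


class EncoderSketch(nn.Module):
    def __init__(self, out_channels, in_hidden_channels):
        super().__init__()
        # one argument, kept under two names, as in the fixed code
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.out_channels = out_channels
        # prenet maps in_channels -> hidden_channels
        self.pre = nn.Sequential(
            nn.Conv1d(self.in_channels, self.hidden_channels, 1),
            nn.ReLU())
        # final projection layers mirror the diff's post_* stack
        self.post_conv = nn.Conv1d(self.hidden_channels,
                                   self.hidden_channels, 1)
        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)

    def forward(self, x, x_mask):
        o = self.pre(x) * x_mask
        o = torch.relu(self.post_bn(self.post_conv(o)))
        return self.post_conv2(o) * x_mask


x = torch.randn(2, 128, 37)        # [B, in_channels, T]
x_mask = torch.ones(2, 1, 37)      # [B, 1, T]
print(EncoderSketch(80, 128)(x, x_mask).shape)  # torch.Size([2, 80, 37])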

@@ -126,6 +126,7 @@ class GlowTts(nn.Module):
             x_lenghts: B
             y: B x C x T
             y_lengths: B
+            g: B x C or B
         """
         y_max_length = y.size(2)
         # norm speaker embeddings
@@ -133,7 +134,7 @@ class GlowTts(nn.Module):
             if self.external_speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
             else:
-                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
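
The comment fix in this hunk is purely about shape: after the unsqueeze, the normalized speaker embedding is [b, h, 1], not [b, h], which is what lets it broadcast against [b, h, T] feature maps downstream. A quick check with invented sizes:

import torch
import torch.nn.functional as F

b, num_speakers, h = 4, 10, 256
emb_g = torch.nn.Embedding(num_speakers, h)  # speaker-id embedding table
g = torch.randint(0, num_speakers, (b,))     # speaker ids: [b]
g = F.normalize(emb_g(g)).unsqueeze(-1)      # [b, h] -> [b, h, 1]
print(g.shape)                               # torch.Size([4, 256, 1])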