update glow-tts parameters and fix rel-attn-win size

erogol 2020-12-17 12:27:51 +01:00
parent 7b20d8cbd3
commit fa6907fa0e
5 changed files with 33 additions and 36 deletions

View File

@@ -5,27 +5,27 @@ from ..generic.normalization import LayerNorm
 class DurationPredictor(nn.Module):
-    def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
         super().__init__()
         # class arguments
         self.in_channels = in_channels
-        self.filter_channels = filter_channels
+        self.filter_channels = hidden_channels
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
         # layers
         self.drop = nn.Dropout(dropout_p)
         self.conv_1 = nn.Conv1d(in_channels,
-                                filter_channels,
+                                hidden_channels,
                                 kernel_size,
                                 padding=kernel_size // 2)
-        self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = nn.Conv1d(filter_channels,
-                                filter_channels,
+        self.norm_1 = LayerNorm(hidden_channels)
+        self.conv_2 = nn.Conv1d(hidden_channels,
+                                hidden_channels,
                                 kernel_size,
                                 padding=kernel_size // 2)
-        self.norm_2 = LayerNorm(filter_channels)
+        self.norm_2 = LayerNorm(hidden_channels)
         # output layer
-        self.proj = nn.Conv1d(filter_channels, 1, 1)
+        self.proj = nn.Conv1d(hidden_channels, 1, 1)

     def forward(self, x, x_mask):
         x = self.conv_1(x * x_mask)
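
Note: a minimal usage sketch of DurationPredictor after the rename (not part of this commit; the import path, channel sizes, and tensor shapes are assumptions for illustration):

    import torch
    from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor  # assumed path

    # The second constructor argument is now `hidden_channels` (previously `filter_channels`).
    dp = DurationPredictor(in_channels=192,
                           hidden_channels=256,
                           kernel_size=3,
                           dropout_p=0.1)
    x = torch.randn(1, 192, 50)      # [batch, channels, time] encoder features
    x_mask = torch.ones(1, 1, 50)    # 1 for valid frames, 0 for padding
    log_durations = dp(x, x_mask)    # -> [1, 1, 50]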

View File

@@ -19,7 +19,7 @@ class Encoder(nn.Module):
         num_chars (int): number of characters.
         out_channels (int): number of output channels.
         hidden_channels (int): encoder's embedding size.
-        filter_channels (int): transformer's feed-forward channels.
+        hidden_channels_ffn (int): transformer's feed-forward channels.
         num_head (int): number of attention heads in transformer.
         num_layers (int): number of transformer encoder stack.
         kernel_size (int): kernel size for conv layers and duration predictor.
@@ -35,12 +35,11 @@ class Encoder(nn.Module):
                 num_chars,
                 out_channels,
                 hidden_channels,
-                filter_channels,
-                filter_channels_dp,
+                hidden_channels_ffn,
+                hidden_channels_dp,
                 encoder_type,
                 num_heads,
                 num_layers,
-                kernel_size,
                 dropout_p,
                 rel_attn_window_size=None,
                 input_length=None,
@@ -52,11 +51,10 @@ class Encoder(nn.Module):
         self.num_chars = num_chars
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.filter_channels_dp = filter_channels_dp
+        self.hidden_channels_ffn = hidden_channels_ffn
+        self.hidden_channels_dp = hidden_channels_dp
         self.num_heads = num_heads
         self.num_layers = num_layers
-        self.kernel_size = kernel_size
         self.dropout_p = dropout_p
         self.mean_only = mean_only
         self.use_prenet = use_prenet
@@ -78,10 +76,10 @@ class Encoder(nn.Module):
             # text encoder
             self.encoder = Transformer(
                 hidden_channels,
-                filter_channels,
+                hidden_channels_ffn,
                 num_heads,
                 num_layers,
-                kernel_size=kernel_size,
+                kernel_size=3,
                 dropout_p=dropout_p,
                 rel_attn_window_size=rel_attn_window_size,
                 input_length=input_length)
@@ -125,7 +123,7 @@ class Encoder(nn.Module):
             self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
         # duration predictor
         self.duration_predictor = DurationPredictor(
-            hidden_channels + c_in_channels, filter_channels_dp, kernel_size,
+            hidden_channels + c_in_channels, hidden_channels_dp, 3,
             dropout_p)

     def forward(self, x, x_lengths, g=None):
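
Note: with this change the Encoder no longer takes `kernel_size`; the transformer and duration-predictor kernels are fixed to 3 internally. A hedged construction sketch (the import path, sizes, the `encoder_type` value, and the remaining defaults are assumptions):

    from TTS.tts.layers.glow_tts.encoder import Encoder  # assumed path

    encoder = Encoder(num_chars=120,
                      out_channels=80,
                      hidden_channels=192,
                      hidden_channels_ffn=768,     # previously `filter_channels`
                      hidden_channels_dp=256,      # previously `filter_channels_dp`
                      encoder_type='transformer',  # assumed value; check encoder.py for accepted types
                      num_heads=2,
                      num_layers=6,
                      dropout_p=0.1,
                      rel_attn_window_size=4)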

View File

@@ -229,23 +229,23 @@ class FFN(nn.Module):
     def __init__(self,
                  in_channels,
                  out_channels,
-                 filter_channels,
+                 hidden_channels,
                  kernel_size,
                  dropout_p=0.,
                  activation=None):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.filter_channels = filter_channels
+        self.hidden_channels = hidden_channels
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
         self.activation = activation

         self.conv_1 = nn.Conv1d(in_channels,
-                                filter_channels,
+                                hidden_channels,
                                 kernel_size,
                                 padding=kernel_size // 2)
-        self.conv_2 = nn.Conv1d(filter_channels,
+        self.conv_2 = nn.Conv1d(hidden_channels,
                                 out_channels,
                                 kernel_size,
                                 padding=kernel_size // 2)
@@ -265,7 +265,7 @@ class FFN(nn.Module):
 class Transformer(nn.Module):
     def __init__(self,
                  hidden_channels,
-                 filter_channels,
+                 hidden_channels_ffn,
                  num_heads,
                  num_layers,
                  kernel_size=1,
@@ -274,7 +274,7 @@ class Transformer(nn.Module):
                  input_length=None):
         super().__init__()
         self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
+        self.hidden_channels_ffn = hidden_channels_ffn
         self.num_heads = num_heads
         self.num_layers = num_layers
         self.kernel_size = kernel_size
@@ -299,7 +299,7 @@ class Transformer(nn.Module):
             self.ffn_layers.append(
                 FFN(hidden_channels,
                     hidden_channels,
-                    filter_channels,
+                    hidden_channels_ffn,
                     kernel_size,
                     dropout_p=dropout_p))
             self.norm_layers_2.append(LayerNorm(hidden_channels))
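
Note: the same rename runs through the transformer stack; the FFN bottleneck width is now `hidden_channels` and the Transformer-level value is `hidden_channels_ffn`. A hedged usage sketch of the FFN block (the import path, shapes, and forward signature are assumptions):

    import torch
    from TTS.tts.layers.glow_tts.transformer import FFN  # assumed path

    ffn = FFN(in_channels=192,
              out_channels=192,
              hidden_channels=768,   # previously `filter_channels`
              kernel_size=3,
              dropout_p=0.1)
    x = torch.randn(1, 192, 50)
    x_mask = torch.ones(1, 1, 50)
    y = ffn(x, x_mask)               # -> [1, 192, 50]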

View File

@@ -14,10 +14,9 @@ class GlowTts(nn.Module):
     def __init__(self,
                  num_chars,
                  hidden_channels,
-                 filter_channels,
-                 filter_channels_dp,
+                 hidden_channels_ffn,
+                 hidden_channels_dp,
                  out_channels,
-                 kernel_size=3,
                  num_heads=2,
                  num_layers_enc=6,
                  dropout_p=0.1,
@@ -43,10 +42,9 @@ class GlowTts(nn.Module):
         super().__init__()
         self.num_chars = num_chars
         self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.filter_channels_dp = filter_channels_dp
+        self.hidden_channels_ffn = hidden_channels_ffn
+        self.hidden_channels_dp = hidden_channels_dp
         self.out_channels = out_channels
-        self.kernel_size = kernel_size
         self.num_heads = num_heads
         self.num_layers_enc = num_layers_enc
         self.dropout_p = dropout_p
@@ -80,13 +78,13 @@ class GlowTts(nn.Module):
         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
                                hidden_channels=hidden_channels,
-                               filter_channels=filter_channels,
-                               filter_channels_dp=filter_channels_dp,
+                               hidden_channels_ffn=hidden_channels_ffn,
+                               hidden_channels_dp=hidden_channels_dp,
                                encoder_type=encoder_type,
                                num_heads=num_heads,
                                num_layers=num_layers_enc,
-                               kernel_size=kernel_size,
                                dropout_p=dropout_p,
                                rel_attn_window_size=rel_attn_window_size,
                                mean_only=mean_only,
                                use_prenet=use_encoder_prenet,
                                c_in_channels=self.c_in_channels)
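
Note: callers that built GlowTts with the old keyword names need the same rename, and `kernel_size` is no longer part of the GlowTts constructor. A hypothetical migration helper, not part of this commit:

    def migrate_glow_tts_kwargs(kwargs):
        """Map pre-rename GlowTts keyword arguments to the new names."""
        renames = {'filter_channels': 'hidden_channels_ffn',
                   'filter_channels_dp': 'hidden_channels_dp'}
        new_kwargs = {renames.get(k, k): v for k, v in kwargs.items()}
        new_kwargs.pop('kernel_size', None)   # dropped from the GlowTts signature in this commit
        return new_kwargs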

View File

@@ -104,8 +104,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     elif c.model.lower() == "glow_tts":
         model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                         hidden_channels=192,
-                        filter_channels=768,
-                        filter_channels_dp=256,
+                        hidden_channels_ffn=768,
+                        hidden_channels_dp=256,
                         out_channels=c.audio['num_mels'],
                         kernel_size=3,
                         num_heads=2,
@@ -126,6 +126,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                         hidden_channels_enc=192,
                         hidden_channels_dec=192,
                         use_encoder_prenet=True,
+                        rel_attn_window_size=4,
                         external_speaker_embedding_dim=speaker_embedding_dim)
     return model
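
Note on the commit title's rel-attn-win fix: `rel_attn_window_size=4` is now passed explicitly, so the encoder's relative positional attention uses a finite window. A small sketch of what that window implies, assuming the standard windowed relative-attention formulation:

    # With a window of 4, each attention position learns relative-position terms
    # for offsets in [-4, 4]; more distant offsets are clipped to the window edge.
    rel_attn_window_size = 4
    num_relative_positions = 2 * rel_attn_window_size + 1   # 9 learned offsets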