From 1b238f04b2c7b018adb9699528d1f0c9353ab4b0 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 17 Aug 2020 12:34:51 +0200 Subject: [PATCH] add gated conv encoder to glow-tts --- TTS/tts/utils/generic_utils.py | 1 + .../tts/layers/glow_tts/encoder.py | 98 ++++++++++++++----- mozilla_voice_tts/tts/models/glow_tts.py | 4 +- 3 files changed, 78 insertions(+), 25 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 6d7ceb75..e93a14f7 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -110,6 +110,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): kernel_size=3, num_heads=2, num_layers_enc=6, + encoder_type=c.encoder_type, dropout_p=0.1, num_flow_blocks_dec=12, kernel_size_dec=5, diff --git a/mozilla_voice_tts/tts/layers/glow_tts/encoder.py b/mozilla_voice_tts/tts/layers/glow_tts/encoder.py index b0e40ca4..c44116ad 100644 --- a/mozilla_voice_tts/tts/layers/glow_tts/encoder.py +++ b/mozilla_voice_tts/tts/layers/glow_tts/encoder.py @@ -4,14 +4,53 @@ from torch import nn from mozilla_voice_tts.tts.layers.glow_tts.transformer import Transformer from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask -from mozilla_voice_tts.tts.layers.glow_tts.glow import ConvLayerNorm +from mozilla_voice_tts.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm from mozilla_voice_tts.tts.layers.glow_tts.duration_predictor import DurationPredictor +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. + """ + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [ + nn.Conv1d(in_out_channels, + 2 * in_out_channels, + kernel_size, + padding=kernel_size // 2) + ] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, + p=self.dropout_p, + training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o + + class Encoder(nn.Module): - """Glow-TTS encoder module. We use Pytorch TransformerEncoder instead - of the one with relative position embedding. We use positional encoding - for capturing positiong information. + """Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding + as in the original paper or GatedConvBlock as a faster alternative. Args: num_chars (int): number of characters. @@ -29,13 +68,13 @@ class Encoder(nn.Module): Shapes: - input: (B, T, C) """ - def __init__(self, num_chars, out_channels, hidden_channels, filter_channels, filter_channels_dp, + encoder_type, num_heads, num_layers, kernel_size, @@ -59,26 +98,36 @@ class Encoder(nn.Module): self.mean_only = mean_only self.use_prenet = use_prenet self.c_in_channels = c_in_channels + self.encoder_type = encoder_type # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - # optional convolutional prenet - if use_prenet: - self.pre = ConvLayerNorm(hidden_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_layers=3, - dropout_p=0.5) - # text encoder - self.encoder = Transformer(hidden_channels, - filter_channels, - num_heads, - num_layers, - kernel_size=kernel_size, - dropout_p=dropout_p, - rel_attn_window_size=rel_attn_window_size, - input_length=input_length) + # init encoder + if encoder_type.lower() == "transformer": + # optional convolutional prenet + if use_prenet: + self.pre = ConvLayerNorm(hidden_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_layers=3, + dropout_p=0.5) + # text encoder + self.encoder = Transformer( + hidden_channels, + filter_channels, + num_heads, + num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + rel_attn_window_size=rel_attn_window_size, + input_length=input_length) + elif encoder_type.lower() == 'gatedconv': + breakpoint() + self.encoder = GatedConvBlock(hidden_channels, + kernel_size=5, + dropout_p=dropout_p, + num_layers=3 + num_layers) # final projection layers self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) if not mean_only: @@ -98,8 +147,9 @@ class Encoder(nn.Module): x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # pre-conv layers - if self.use_prenet: - x = self.pre(x, x_mask) + if self.encoder_type == 'transformer': + if self.use_prenet: + x = self.pre(x, x_mask) # encoder x = self.encoder(x, x_mask) # set duration predictor input diff --git a/mozilla_voice_tts/tts/models/glow_tts.py b/mozilla_voice_tts/tts/models/glow_tts.py index 9bb96ae4..b7551086 100644 --- a/mozilla_voice_tts/tts/models/glow_tts.py +++ b/mozilla_voice_tts/tts/models/glow_tts.py @@ -36,7 +36,8 @@ class GlowTts(nn.Module): mean_only=False, hidden_channels_enc=None, hidden_channels_dec=None, - use_encoder_prenet=False): + use_encoder_prenet=False, + encoder_type="transformer"): super().__init__() self.num_chars = num_chars @@ -72,6 +73,7 @@ class GlowTts(nn.Module): hidden_channels_enc or hidden_channels, filter_channels, filter_channels_dp, + encoder_type, num_heads, num_layers_enc, kernel_size,