From c2d29e5cd42bdce8e94170503b1fdf28b36692c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 Mar 2021 17:05:52 +0100
Subject: [PATCH] FFTransformer encoder for aligntts

---
 TTS/tts/layers/feed_forward/encoder.py | 72 +++++---------------------
 1 file changed, 14 insertions(+), 58 deletions(-)

diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py
index 8086286c..3edf339d 100644
--- a/TTS/tts/layers/feed_forward/encoder.py
+++ b/TTS/tts/layers/feed_forward/encoder.py
@@ -4,59 +4,7 @@ from torch import nn
 
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
-
-
-
-class PositionalEncoding(nn.Module):
-    """Sinusoidal positional encoding for non-recurrent neural networks.
-    Implementation based on "Attention Is All You Need"
-    Args:
-       channels (int): embedding size
-       dropout (float): dropout parameter
-    """
-    def __init__(self, channels, dropout=0.0, max_len=5000):
-        super().__init__()
-        if channels % 2 != 0:
-            raise ValueError(
-                "Cannot use sin/cos positional encoding with "
-                "odd channels (got channels={:d})".format(channels))
-        pe = torch.zeros(max_len, channels)
-        position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
-                              -(math.log(10000.0) / channels)))
-        pe[:, 0::2] = torch.sin(position.float() * div_term)
-        pe[:, 1::2] = torch.cos(position.float() * div_term)
-        pe = pe.unsqueeze(0).transpose(1, 2)
-        self.register_buffer('pe', pe)
-        if dropout > 0:
-            self.dropout = nn.Dropout(p=dropout)
-        self.channels = channels
-
-    def forward(self, x, mask=None, first_idx=None, last_idx=None):
-        """
-        Shapes:
-            x: [B, C, T]
-            mask: [B, 1, T]
-            first_idx: int
-            last_idx: int
-        """
-
-        x = x * math.sqrt(self.channels)
-        if first_idx is None:
-            if self.pe.size(2) < x.size(2):
-                raise RuntimeError(
-                    f"Sequence is {x.size(2)} but PositionalEncoding is"
-                    f" limited to {self.pe.size(2)}. See max_len argument.")
-            if mask is not None:
-                pos_enc = (self.pe[:, :, :x.size(2)] * mask)
-            else:
-                pos_enc = self.pe[:, :, :x.size(2)]
-            x = x + pos_enc
-        else:
-            x = x + self.pe[:, :, first_idx:last_idx]
-        if hasattr(self, 'dropout'):
-            x = self.dropout(x)
-        return x
+from TTS.tts.layers.generic.transformer import FFTransformersBlock
 
 
 class RelativePositionTransformerEncoder(nn.Module):
@@ -138,9 +86,9 @@ class Encoder(nn.Module):
         c_in_channels (int): number of channels for conditional input.
 
     Note:
-        Default encoder_params...
+        Default encoder_params to be set in config.json...
 
-        for 'transformer'
+        for 'relative_position_transformer'
             encoder_params={
                 'hidden_channels_ffn': 128,
                 'num_heads': 2,
@@ -158,6 +106,14 @@ class Encoder(nn.Module):
                 "num_conv_blocks": 2,
                 "num_res_blocks": 13
             }
+
+        for 'transformer_decoder'
+            encoder_params = {
+                hidden_channels_ffn: 1024 ,
+                num_heads: 2,
+                num_layers: 6,
+                dropout_p: 0.1
+            }
     """
     def __init__(
             self,
@@ -179,7 +135,7 @@ class Encoder(nn.Module):
         self.c_in_channels = c_in_channels
 
         # init encoder
-        if encoder_type.lower() == "transformer":
+        if encoder_type.lower() == "relative_position_transformer":
             # text encoder
             self.encoder = RelativePositionTransformerEncoder(
                 in_hidden_channels, out_channels, in_hidden_channels,
@@ -189,11 +145,11 @@ class Encoder(nn.Module):
                                                    out_channels,
                                                    in_hidden_channels,
                                                    encoder_params)
+        elif encoder_type.lower() == 'transformer':
+            self.encoder = FFTransformersBlock(in_hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
         else:
             raise NotImplementedError(' [!] unknown encoder type.')
 
-        # final projection layers
-
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         """
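
Usage sketch (not taken from the patch): how the new 'transformer' encoder type could be selected through the Encoder factory, assuming the keyword names visible in the hunks above (in_hidden_channels, out_channels, encoder_type, encoder_params), the encoder_params values from the added docstring block, and the [B, C, T] / [B, 1, T] input and mask layout documented in the removed PositionalEncoding. Channel sizes, batch and time dimensions below are illustrative only.

import torch

from TTS.tts.layers.feed_forward.encoder import Encoder

# encoder_params mirrors the 'transformer_decoder' example added to the
# docstring by this patch; all sizes here are assumptions for illustration.
encoder = Encoder(
    in_hidden_channels=128,
    out_channels=128,
    encoder_type="transformer",  # routed to FFTransformersBlock in __init__
    encoder_params={
        "hidden_channels_ffn": 1024,
        "num_heads": 2,
        "num_layers": 6,
        "dropout_p": 0.1,
    },
)

x = torch.rand(2, 128, 37)       # [B, C, T] character/phoneme embeddings
x_mask = torch.ones(2, 1, 37)    # [B, 1, T] padding mask
o = encoder(x, x_mask)           # forward(x, x_mask, g=None) as shown above
print(o.shape)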