From 1d961d6f8a134c5bbb40ffeb6d79b6929737474c Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 11 Jan 2021 17:26:11 +0100
Subject: [PATCH] class renaming

---
 TTS/tts/layers/glow_tts/encoder.py     | 19 ++++++++++++-------
 TTS/tts/layers/glow_tts/transformer.py | 22 ++++++++++++++++------
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py
index ab7aaba5..9a1508ee 100644
--- a/TTS/tts/layers/glow_tts/encoder.py
+++ b/TTS/tts/layers/glow_tts/encoder.py
@@ -5,10 +5,10 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
+from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock


 class Encoder(nn.Module):
@@ -97,14 +97,16 @@ class Encoder(nn.Module):
         # init encoder module
         if encoder_type.lower() == "rel_pos_transformer":
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                             hidden_channels,
                                             hidden_channels,
                                             kernel_size=5,
                                             num_layers=3,
                                             dropout_p=0.5)
-            self.encoder = RelativePositionTransformer(
-                hidden_channels, **encoder_params)
+            self.encoder = RelativePositionTransformer(hidden_channels,
+                                                       hidden_channels,
+                                                       hidden_channels,
+                                                       **encoder_params)
         elif encoder_type.lower() == 'gated_conv':
             self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
         elif encoder_type.lower() == 'residual_conv_bn':
@@ -113,13 +115,16 @@ class Encoder(nn.Module):
                 nn.Conv1d(hidden_channels, hidden_channels, 1),
                 nn.ReLU()
             )
-            self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+            self.encoder = ResidualConv1dBNBlock(hidden_channels,
+                                                 hidden_channels,
+                                                 hidden_channels,
+                                                 **encoder_params)
             self.postnet = nn.Sequential(
                 nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
                 nn.BatchNorm1d(self.hidden_channels))
         elif encoder_type.lower() == 'time_depth_separable':
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                             hidden_channels,
                                             hidden_channels,
                                             kernel_size=5,
diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index a872a175..81160006 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -264,7 +264,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         return diff.unsqueeze(0).unsqueeze(0)


-class FFN(nn.Module):
+class FeedForwardNetwork(nn.Module):
     """Feed Forward Inner layers for Transformer.

     Args:
@@ -326,6 +326,8 @@ class RelativePositionTransformer(nn.Module):
         input_length (int, optional): input length to limit position encoding. Defaults to None.
""" def __init__(self, + in_channels, + out_channels, hidden_channels, hidden_channels_ffn, num_heads, @@ -348,23 +350,31 @@ class RelativePositionTransformer(nn.Module): self.norm_layers_1 = nn.ModuleList() self.ffn_layers = nn.ModuleList() self.norm_layers_2 = nn.ModuleList() - for _ in range(self.num_layers): + + for idx in range(self.num_layers): self.attn_layers.append( RelativePositionMultiHeadAttention( - hidden_channels, + hidden_channels if idx != 0 else in_channels, hidden_channels, num_heads, rel_attn_window_size=rel_attn_window_size, dropout_p=dropout_p, input_length=input_length)) self.norm_layers_1.append(LayerNorm(hidden_channels)) + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.ffn_layers.append( - FFN(hidden_channels, - hidden_channels, + FeedForwardNetwork(hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, hidden_channels_ffn, kernel_size, dropout_p=dropout_p)) - self.norm_layers_2.append(LayerNorm(hidden_channels)) + + self.norm_layers_2.append( + LayerNorm(hidden_channels if ( + idx + 1) != self.num_layers else out_channels)) def forward(self, x, x_mask): """