diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index 1c4bf99b..d0a1f3e9 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -15,10 +15,11 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): class WN(torch.nn.Module): """Wavenet layers with weight norm and no input conditioning. - x -> conv1d(dilation) -> dropout -> split(z + g) -> tanh(z1) -> t * s -> conv1d1x1 -> z + x -> ... - g - - - - - - - - - - - - - - - - - - - - - ^ -> sigmoid(z2) - - - ^ | - ˅ - o + z + |-----------------------------------------------------------------------------| + | |-> tanh -| | + res -|- conv1d(dilation) -> dropout -> + -| * -> conv1d1x1 -> split -|- + -> res + g -------------------------------------| |-> sigmoid -| | + o --------------------------------------------------------------------------- + --------- o Args: in_channels (int): number of input channels. @@ -55,12 +56,13 @@ class WN(torch.nn.Module): self.res_skip_layers = torch.nn.ModuleList() self.dropout = nn.Dropout(dropout_p) + # init conditioning layer if c_in_channels > 0: cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') - + # intermediate layers for i in range(num_layers): dilation = dilation_rate**i padding = int((kernel_size * dilation - dilation) / 2) @@ -124,7 +126,23 @@ class WN(torch.nn.Module): class WNBlocks(nn.Module): - """Wavenet blocks""" + """Wavenet blocks. + + Note: After each block dilation resets to 1 and it increases in each block + along the dilation rate. + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_blocks (int): number of wavenet blocks. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. + """ def __init__(self, in_channels, diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index 07826fae..f6385747 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -6,14 +6,23 @@ from TTS.tts.layers.generic.wavenet import WN from ..generic.normalization import LayerNorm -class ConvLayerNorm(nn.Module): - """Residual Convolution with LayerNorm - x -> conv1d -> layer_norm -> dropout -> + -> o - | | - ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯ - """ +class ResidualConv1dLayerNormBlock(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): + """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o + |---------------> conv1d_1x1 -----------------------| + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of inner layer channels. + out_channels (int): number of output tensor channels. + kernel_size (int): kernel size of conv1d filter. + num_layers (int): number of blocks. + dropout_p (float): dropout rate for each block. + """ super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -50,7 +59,9 @@ class ConvLayerNorm(nn.Module): class InvConvNear(nn.Module): - """Inversible Convolution with splitting. + """Invertible Convolution with input splitting as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + Args: channels (int): input and output channels. num_splits (int): number of splits, also H and W of conv layer. @@ -122,8 +133,8 @@ class InvConvNear(nn.Module): class CouplingBlock(nn.Module): - """Glow Affine Coupling block. - For details https://arxiv.org/pdf/1811.00002.pdf + """Glow Affine Coupling block as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index b3a045e5..a872a175 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -311,6 +311,8 @@ class RelativePositionTransformer(nn.Module): https://arxiv.org/abs/1803.02155 Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. hidden_channels (int): model hidden channels. hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. num_heads (int): number of attention heads.