more docstrings

erogol 2021-01-11 17:25:04 +01:00
parent 6e9043c5d2
commit b206162d11
3 changed files with 46 additions and 15 deletions

@@ -15,10 +15,11 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
class WN(torch.nn.Module):
"""Wavenet layers with weight norm and no input conditioning.
x -> conv1d(dilation) -> dropout -> split(z + g) -> tanh(z1) -> t * s -> conv1d1x1 -> z + x -> ...
g - - - - - - - - - - - - - - - - - - - - ^ -> sigmoid(z2) - - - ^                       |
                                                                                         ˅
                                                                                       o + z

     |-----------------------------------------------------------------------------|
     |                                    |-> tanh    -|                            |
res -|- conv1d(dilation) -> dropout -> + -|            * -> conv1d1x1 -> split -|- + -> res
g -------------------------------------|  |-> sigmoid -|                        |
o --------------------------------------------------------------------------- + --------- o
Args:
in_channels (int): number of input channels.
@@ -55,12 +56,13 @@ class WN(torch.nn.Module):
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(dropout_p)
        # init conditioning layer
        if c_in_channels > 0:
            cond_layer = torch.nn.Conv1d(c_in_channels,
                                         2 * hidden_channels * num_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
                                                         name='weight')
        # intermediate layers
        for i in range(num_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
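
The loop above gives each layer a "same"-padded dilated convolution. A quick standalone sketch of the schedule this produces; the values below are illustrative, not part of the commit:

# Sketch: per-layer dilation/padding, assuming dilation_rate=2, kernel_size=3.
kernel_size, dilation_rate, num_layers = 3, 2, 4
for i in range(num_layers):
    dilation = dilation_rate ** i
    padding = int((kernel_size * dilation - dilation) / 2)  # keeps T unchanged
    print(f"layer {i}: dilation={dilation}, padding={padding}")
# layer 0: dilation=1, padding=1 ... layer 3: dilation=8, padding=8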
@@ -124,7 +126,23 @@ class WN(torch.nn.Module):
class WNBlocks(nn.Module):
"""Wavenet blocks"""
"""Wavenet blocks.
Note: The dilation resets to 1 at the start of each block and grows by the
dilation rate from layer to layer within a block.
Args:
in_channels (int): number of input channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): filter kernel size for the first conv layer.
dilation_rate (int): rate by which the dilation grows per layer.
If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers.
num_blocks (int): number of wavenet blocks.
num_layers (int): number of wavenet layers.
c_in_channels (int): number of channels of conditioning input.
dropout_p (float): dropout rate.
weight_norm (bool): enable/disable weight norm for convolution layers.
"""
def __init__(self,
in_channels,
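
Per the note in the docstring, WNBlocks restarts the dilation schedule in every block. A minimal sketch of the resulting per-layer dilations; the block and layer counts are made up for illustration:

# Each block restarts at dilation 1; within a block it grows by dilation_rate.
dilation_rate, num_blocks, num_layers = 2, 2, 4
schedule = [[dilation_rate ** i for i in range(num_layers)]
            for _ in range(num_blocks)]
print(schedule)  # [[1, 2, 4, 8], [1, 2, 4, 8]]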

@@ -6,14 +6,23 @@ from TTS.tts.layers.generic.wavenet import WN
from ..generic.normalization import LayerNorm
class ConvLayerNorm(nn.Module):
"""Residual Convolution with LayerNorm
x -> conv1d -> layer_norm -> dropout -> + -> o
|                                       |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
"""
class ResidualConv1dLayerNormBlock(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
num_layers, dropout_p):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of inner layer channels.
out_channels (int): number of output tensor channels.
kernel_size (int): kernel size of conv1d filter.
num_layers (int): number of convolution layers in the block.
dropout_p (float): dropout rate for each block.
"""
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
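
A runnable single-layer sketch of the dataflow drawn in the docstring (conv1d -> layer_norm -> relu -> dropout, plus the conv1d_1x1 shortcut). The class below is our illustration, not the repo's ResidualConv1dLayerNormBlock:

import torch
from torch import nn

class ResidualConvLNSketch(nn.Module):
    """One layer of the diagrammed block; a sketch, not the repo class."""
    def __init__(self, channels, kernel_size, dropout_p):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size,
                              padding=kernel_size // 2)
        self.norm = nn.LayerNorm(channels)  # normalizes the channel dim via transposes
        self.dropout = nn.Dropout(dropout_p)
        self.res = nn.Conv1d(channels, channels, 1)  # the conv1d_1x1 shortcut

    def forward(self, x):  # x: (B, C, T)
        o = self.conv(x)
        o = self.norm(o.transpose(1, 2)).transpose(1, 2)
        o = self.dropout(torch.relu(o))
        return o + self.res(x)

print(ResidualConvLNSketch(80, 5, 0.1)(torch.randn(2, 80, 100)).shape)  # (2, 80, 100)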
@@ -50,7 +59,9 @@ class ConvLayerNorm(nn.Module):
class InvConvNear(nn.Module):
"""Inversible Convolution with splitting.
"""Invertible Convolution with input splitting as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
Args:
channels (int): input and output channels.
num_splits (int): number of splits, also H and W of conv layer.
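
The "near" 1x1 convolution stays invertible because the num_splits x num_splits mixing matrix is square and non-singular; splitting only decides how channels are grouped before mixing. A numeric sketch of that idea, with the grouping simplified relative to the repo's implementation:

import torch

channels, num_splits, T = 8, 4, 10
w = torch.linalg.qr(torch.randn(num_splits, num_splits))[0]  # orthogonal, so invertible

x = torch.randn(1, channels // num_splits, num_splits, T)    # channels grouped into splits
z = torch.einsum('ij,bgjt->bgit', w, x)                      # forward: mix within each group
x_rec = torch.einsum('ij,bgjt->bgit', torch.inverse(w), z)   # inverse: unmix exactly
print(torch.allclose(x_rec, x, atol=1e-5))                   # True
# log-determinant contribution: (channels // num_splits) * T * torch.slogdet(w)[1]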
@@ -122,8 +133,8 @@ class InvConvNear(nn.Module):
class CouplingBlock(nn.Module):
"""Glow Affine Coupling block.
For details https://arxiv.org/pdf/1811.00002.pdf
"""Glow Affine Coupling block as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
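
Because t and s are computed from x0 only, the coupling is exactly invertible. A numeric sketch of the diagrammed dataflow, with a toy net() standing in for the conv1d -> wavenet -> conv1d stack:

import torch

def net(x0):  # toy stand-in for conv1d -> wavenet -> conv1d producing (t, s)
    return x0 * 0.5, torch.full_like(x0, 0.9)

x = torch.randn(2, 8, 10)
x0, x1 = x.chunk(2, dim=1)
t, s = net(x0)
o = torch.cat([s * x1 + t, x0], dim=1)           # concat(s*x1 + t, x0) as in the diagram

y, x0_rec = o.chunk(2, dim=1)                    # inverse: x0 passes through unchanged
t, s = net(x0_rec)
print(torch.allclose((y - t) / s, x1, atol=1e-6))  # True: x1 is recovered exactly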

@@ -311,6 +311,8 @@ class RelativePositionTransformer(nn.Module):
https://arxiv.org/abs/1803.02155
Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.