mirror of https://github.com/coqui-ai/TTS.git

commit b206162d11 ("more docstrings")
parent 6e9043c5d2
@@ -15,10 +15,11 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
 class WN(torch.nn.Module):
     """Wavenet layers with weight norm and no input conditioning.
 
-    x -> conv1d(dilation) -> dropout -> split(z + g) -> tanh(z1) -> t * s -> conv1d1x1 -> z + x -> ...
-    g - - - - - - - - - - - - - - - - - - - - - ^ -> sigmoid(z2) - - - ^ |
-                                                                        ˅
-                                                                      o + z
+         |-----------------------------------------------------------------------------|
+         |                                    |-> tanh    -|                           |
+    res -|- conv1d(dilation) -> dropout -> + -|            * -> conv1d1x1 -> split -|- + -> res
+    g -------------------------------------|  |-> sigmoid -|                        |
+    o --------------------------------------------------------------------------- + --------- o
 
     Args:
         in_channels (int): number of input channels.
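The hunk header above references fused_add_tanh_sigmoid_multiply, the gated activation that the tanh/sigmoid branches in the new diagram describe. A minimal sketch of that gate in the usual WaveGlow-style convention (tanh half and sigmoid half stacked along the channel axis; not copied from this commit):

import torch

def gated_activation_sketch(input_a, input_b, n_channels):
    # add the conditioning tensor, then gate: tanh on the first n_channels,
    # sigmoid on the remaining channels, and multiply the two halves
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels, :])
    s_act = torch.sigmoid(in_act[:, n_channels:, :])
    return t_act * s_act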
@@ -55,12 +56,13 @@ class WN(torch.nn.Module):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(dropout_p)
 
         # init conditioning layer
         if c_in_channels > 0:
             cond_layer = torch.nn.Conv1d(c_in_channels,
                                          2 * hidden_channels * num_layers, 1)
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
                                                          name='weight')
 
         # intermediate layers
         for i in range(num_layers):
             dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
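The loop above grows the dilation geometrically, and the padding formula keeps the output length equal to the input length for odd kernel sizes. The schedule for dilation_rate=2 and kernel_size=3:

kernel_size, dilation_rate = 3, 2
for i in range(4):
    dilation = dilation_rate ** i
    padding = int((kernel_size * dilation - dilation) / 2)
    print(i, dilation, padding)
# 0 1 1
# 1 2 2
# 2 4 4
# 3 8 8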
@@ -124,7 +126,23 @@ class WN(torch.nn.Module):
 
 
 class WNBlocks(nn.Module):
-    """Wavenet blocks"""
+    """Wavenet blocks.
+
+    Note: after each block the dilation resets to 1, and within each block
+        it increases again by the dilation rate.
+
+    Args:
+        in_channels (int): number of input channels.
+        hidden_channels (int): number of hidden channels.
+        kernel_size (int): filter kernel size for the first conv layer.
+        dilation_rate (int): dilation rate used to increase the dilation per layer.
+            If it is 2, the dilations are 1, 2, 4, 8 for the next 4 layers.
+        num_blocks (int): number of wavenet blocks.
+        num_layers (int): number of wavenet layers in each block.
+        c_in_channels (int): number of channels of the conditioning input.
+        dropout_p (float): dropout rate.
+        weight_norm (bool): enable/disable weight norm for the convolution layers.
+    """
 
     def __init__(self,
                  in_channels,
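The Note in the new docstring is easiest to see numerically: the dilation restarts at 1 in every block and grows by the dilation rate within it.

dilation_rate, num_blocks, num_layers = 2, 2, 4
schedule = [[dilation_rate ** i for i in range(num_layers)]
            for _ in range(num_blocks)]
print(schedule)  # [[1, 2, 4, 8], [1, 2, 4, 8]]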
@@ -6,14 +6,23 @@ from TTS.tts.layers.generic.wavenet import WN
 from ..generic.normalization import LayerNorm
 
 
-class ConvLayerNorm(nn.Module):
-    """Residual Convolution with LayerNorm
-    x -> conv1d -> layer_norm -> dropout -> + -> o
-    |                                       |
-    ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
-    """
+class ResidualConv1dLayerNormBlock(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                  num_layers, dropout_p):
+        """Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
+        https://arxiv.org/pdf/1811.00002.pdf
+
+        x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
+        |---------------> conv1d_1x1 -----------------------|
+
+        Args:
+            in_channels (int): number of input tensor channels.
+            hidden_channels (int): number of inner layer channels.
+            out_channels (int): number of output tensor channels.
+            kernel_size (int): kernel size of conv1d filter.
+            num_layers (int): number of blocks.
+            dropout_p (float): dropout rate for each block.
+        """
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
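A self-contained sketch of what the new diagram describes: a conv1d -> layer_norm -> relu -> dropout stack with a conv1d_1x1 shortcut feeding the residual sum. Names and details here are illustrative, not the repo's implementation:

import torch
from torch import nn

class ResidualConvLayerNormSketch(nn.Module):
    # hypothetical re-implementation of the diagrammed block
    def __init__(self, in_channels, hidden_channels, out_channels,
                 kernel_size, num_layers, dropout_p):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            in_c = in_channels if i == 0 else hidden_channels
            out_c = out_channels if i == num_layers - 1 else hidden_channels
            self.convs.append(nn.Conv1d(in_c, out_c, kernel_size,
                                        padding=kernel_size // 2))
            self.norms.append(nn.LayerNorm(out_c))
        self.shortcut = nn.Conv1d(in_channels, out_channels, 1)  # conv1d_1x1 branch
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):  # x: (batch, channels, time)
        res = self.shortcut(x)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = norm(x.transpose(1, 2)).transpose(1, 2)  # normalize over channels
            x = self.dropout(torch.relu(x))
        return x + res  # + -> o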
@@ -50,7 +59,9 @@ class ConvLayerNorm(nn.Module):
 
 
 class InvConvNear(nn.Module):
-    """Inversible Convolution with splitting.
+    """Invertible Convolution with input splitting as in GlowTTS paper.
+    https://arxiv.org/pdf/1811.00002.pdf
+
     Args:
         channels (int): input and output channels.
         num_splits (int): number of splits, also H and W of conv layer.
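For context, a minimal Glow-style invertible 1x1 convolution. The repo's InvConvNear additionally regroups the input into num_splits channel groups before applying the weight, which this sketch omits:

import torch
from torch import nn

class InvConv1x1Sketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # random orthogonal init: |det W| = 1, so the initial log-det is 0
        w = torch.linalg.qr(torch.randn(channels, channels))[0]
        self.weight = nn.Parameter(w)

    def forward(self, x, reverse=False):  # x: (batch, channels, time)
        t = x.size(2)
        if reverse:
            return torch.matmul(torch.inverse(self.weight), x), None
        logdet = torch.slogdet(self.weight)[1] * t  # log|det W| per frame, summed over time
        return torch.matmul(self.weight, x), logdet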
@@ -122,8 +133,8 @@ class InvConvNear(nn.Module):
 
 
 class CouplingBlock(nn.Module):
-    """Glow Affine Coupling block.
-    For details https://arxiv.org/pdf/1811.00002.pdf
+    """Glow Affine Coupling block as in GlowTTS paper.
+    https://arxiv.org/pdf/1811.00002.pdf
 
     x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
     '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
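A minimal sketch of the affine coupling step in the diagram, with a plain 1x1-conv stack standing in for the conv1d -> wavenet -> conv1d network; the ordering of the halves in the final concat is only a convention here, and nothing is copied from the repo:

import torch
from torch import nn

class AffineCouplingSketch(nn.Module):
    def __init__(self, channels, hidden_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(channels // 2, hidden_channels, 1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, channels, 1),  # predicts t and log_s
        )

    def forward(self, x, reverse=False):  # x: (batch, channels, time)
        x0, x1 = x.chunk(2, dim=1)               # split the channels
        t, log_s = self.net(x0).chunk(2, dim=1)  # t and s computed from x0 only
        if reverse:
            x1 = (x1 - t) * torch.exp(-log_s)    # exact inverse of the affine map
            return torch.cat([x0, x1], dim=1)
        x1 = torch.exp(log_s) * x1 + t           # s * x1 + t
        return torch.cat([x0, x1], dim=1), log_s.sum((1, 2))  # per-sample log-det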
@@ -311,6 +311,8 @@ class RelativePositionTransformer(nn.Module):
     https://arxiv.org/abs/1803.02155
 
     Args:
         in_channels (int): number of channels of the input tensor.
         out_channels (int): number of channels of the output tensor.
         hidden_channels (int): model hidden channels.
         hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
         num_heads (int): number of attention heads.
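RelativePositionTransformer refers to self-attention with relative position representations (Shaw et al., 2018, the arXiv link above). A sketch of how a relative-position term enters the attention scores; all names here are illustrative, not the repo's API:

import torch

def relative_attention_scores(q, k, rel_emb):
    # q, k: (batch, heads, time, head_dim); rel_emb: (2*time - 1, head_dim)
    b, h, t, d = q.shape
    scores = torch.matmul(q, k.transpose(-2, -1))   # content-content term
    rel = torch.matmul(q, rel_emb.transpose(0, 1))  # content-position term: (b, h, t, 2t-1)
    # entry (i, j) should read the embedding for relative offset j - i,
    # which is stored at index t - 1 + j - i
    i = torch.arange(t).view(t, 1)
    j = torch.arange(t).view(1, t)
    idx = (t - 1 + j - i).view(1, 1, t, t).expand(b, h, t, t)
    return (scores + rel.gather(-1, idx)) / d ** 0.5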