mirror of https://github.com/coqui-ai/TTS.git

more docstrings

parent 6e9043c5d2
commit b206162d11
@@ -15,10 +15,11 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
 class WN(torch.nn.Module):
     """Wavenet layers with weight norm and no input conditioning.
 
-    x -> conv1d(dilation) -> dropout -> split(z + g) -> tanh(z1) -> t * s -> conv1d1x1 -> z + x -> ...
-    g - - - - - - - - - - - - - - - - - - - - - ^ -> sigmoid(z2) - - - ^
-                                                                      ˅
-    o + z
+         |-----------------------------------------------------------------------------|
+         |                                    |-> tanh    -|                            |
+    res -|- conv1d(dilation) -> dropout -> + -|            * -> conv1d1x1 -> split -|- + -> res
+    g -------------------------------------|  |-> sigmoid -|                        |
+    o --------------------------------------------------------------------------- + --------- o
 
     Args:
         in_channels (int): number of input channels.
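
The hunk header references fused_add_tanh_sigmoid_multiply, the gated activation that the tanh/sigmoid branches of the new diagram describe. Below is a minimal sketch of that gating pattern, assuming the conventional WaveNet layout where the first half of the channels goes through tanh and the second half through sigmoid; the function name and signature here are illustrative, not the module's exact code.

    import torch

    def gated_activation(input_a, input_b, n_channels):
        # Sum the dilated-conv output with the conditioning signal,
        # then gate: first n_channels -> tanh, remaining -> sigmoid.
        in_act = input_a + input_b
        t_act = torch.tanh(in_act[:, :n_channels, :])     # tanh branch
        s_act = torch.sigmoid(in_act[:, n_channels:, :])  # sigmoid branch
        return t_act * s_act                              # elementwise gate
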
@@ -55,12 +56,13 @@ class WN(torch.nn.Module):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(dropout_p)
 
+        # init conditioning layer
         if c_in_channels > 0:
             cond_layer = torch.nn.Conv1d(c_in_channels,
                                          2 * hidden_channels * num_layers, 1)
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
                                                          name='weight')
+        # intermediate layers
        for i in range(num_layers):
             dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
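
The padding formula in the last context line reduces to dilation * (kernel_size - 1) / 2, which keeps the time dimension unchanged for stride-1 convolutions with odd kernel sizes. A quick illustrative check with assumed values (not part of the diff):

    kernel_size, dilation_rate, num_layers = 3, 2, 4
    for i in range(num_layers):
        dilation = dilation_rate ** i
        padding = int((kernel_size * dilation - dilation) / 2)
        print(i, dilation, padding)  # prints (0,1,1) (1,2,2) (2,4,4) (3,8,8)
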
@@ -124,7 +126,23 @@ class WN(torch.nn.Module):
 
 
 class WNBlocks(nn.Module):
-    """Wavenet blocks"""
+    """Wavenet blocks.
+
+    Note: After each block, dilation resets to 1 and increases by the
+        dilation rate within each block.
+
+    Args:
+        in_channels (int): number of input channels.
+        hidden_channels (int): number of hidden channels.
+        kernel_size (int): filter kernel size for the first conv layer.
+        dilation_rate (int): rate by which the dilation increases per layer.
+            If it is 2, the dilations are 1, 2, 4, 8 for the next 4 layers.
+        num_blocks (int): number of wavenet blocks.
+        num_layers (int): number of wavenet layers per block.
+        c_in_channels (int): number of channels of the conditioning input.
+        dropout_p (float): dropout rate.
+        weight_norm (bool): enable/disable weight norm for the convolution layers.
+    """
 
     def __init__(self,
                  in_channels,
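
To make the Note above concrete: the per-layer dilation schedule restarts in every block. A small sketch with assumed values (dilation_rate=2, num_blocks=2, num_layers=4):

    dilation_rate, num_blocks, num_layers = 2, 2, 4
    schedule = [[dilation_rate ** i for i in range(num_layers)]
                for _ in range(num_blocks)]
    print(schedule)  # [[1, 2, 4, 8], [1, 2, 4, 8]]
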
@@ -6,14 +6,23 @@ from TTS.tts.layers.generic.wavenet import WN
 from ..generic.normalization import LayerNorm
 
 
-class ConvLayerNorm(nn.Module):
-    """Residual Convolution with LayerNorm
-    x -> conv1d -> layer_norm -> dropout -> + -> o
-    |                                       |
-    ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
-    """
+class ResidualConv1dLayerNormBlock(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                  num_layers, dropout_p):
+        """Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
+        https://arxiv.org/pdf/1811.00002.pdf
+
+        x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
+          |---------------> conv1d_1x1 -----------------------|
+
+        Args:
+            in_channels (int): number of input tensor channels.
+            hidden_channels (int): number of inner layer channels.
+            out_channels (int): number of output tensor channels.
+            kernel_size (int): kernel size of conv1d filter.
+            num_layers (int): number of blocks.
+            dropout_p (float): dropout rate for each block.
+        """
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
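
The new docstring's diagram amounts to a stack of conv1d -> layer_norm -> relu -> dropout layers plus a 1x1-conv residual path. Below is a self-contained sketch of that structure (simplified: no sequence masking, plain nn.LayerNorm applied over channels via a transpose; class and attribute names are illustrative, not the TTS module's actual ones):

    import torch
    import torch.nn as nn

    class ResidualConvLNSketch(nn.Module):
        # Illustrative re-implementation of the diagram, not the TTS module itself.
        def __init__(self, in_channels, hidden_channels, out_channels,
                     kernel_size, num_layers, dropout_p):
            super().__init__()
            self.convs = nn.ModuleList()
            self.norms = nn.ModuleList()
            for i in range(num_layers):
                c_in = in_channels if i == 0 else hidden_channels
                c_out = out_channels if i == num_layers - 1 else hidden_channels
                self.convs.append(nn.Conv1d(c_in, c_out, kernel_size,
                                            padding=kernel_size // 2))
                self.norms.append(nn.LayerNorm(c_out))
            self.res = nn.Conv1d(in_channels, out_channels, 1)  # conv1d_1x1 path
            self.drop = nn.Dropout(dropout_p)

        def forward(self, x):  # x: [B, C, T]
            y = x
            for conv, norm in zip(self.convs, self.norms):
                y = conv(y)
                y = norm(y.transpose(1, 2)).transpose(1, 2)  # LN over channels
                y = self.drop(torch.relu(y))
            return y + self.res(x)  # residual connection
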
@@ -50,7 +59,9 @@ class ConvLayerNorm(nn.Module):
 
 
 class InvConvNear(nn.Module):
-    """Inversible Convolution with splitting.
+    """Invertible Convolution with input splitting as in GlowTTS paper.
+    https://arxiv.org/pdf/1811.00002.pdf
+
     Args:
         channels (int): input and output channels.
         num_splits (int): number of splits, also H and W of conv layer.
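
The core idea behind an invertible 1x1 convolution is that a square weight matrix mixes channel groups, and its matrix inverse undoes the transform exactly. A minimal sketch of that idea, without the channel regrouping and log-determinant bookkeeping the actual layer needs (values and shapes are assumptions):

    import torch

    torch.manual_seed(0)
    num_splits = 4
    # Orthogonal init (as in Glow) guarantees the weight starts invertible.
    w = torch.linalg.qr(torch.randn(num_splits, num_splits))[0]

    x = torch.randn(8, num_splits, 50)                        # [B, C, T]
    y = torch.einsum('ij,bjt->bit', w, x)                     # forward 1x1 "conv"
    x_rec = torch.einsum('ij,bjt->bit', torch.inverse(w), y)  # inverse pass
    print(torch.allclose(x, x_rec, atol=1e-5))                # True
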
@@ -122,8 +133,8 @@ class InvConvNear(nn.Module):
 
 
 class CouplingBlock(nn.Module):
-    """Glow Affine Coupling block.
-    For details https://arxiv.org/pdf/1811.00002.pdf
+    """Glow Affine Coupling block as in GlowTTS paper.
+    https://arxiv.org/pdf/1811.00002.pdf
 
     x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
     '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
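
The diagram in this docstring is the standard affine coupling pattern: half the channels pass through unchanged and parameterize an affine transform of the other half, which makes the inverse trivial. A stripped-down sketch of the mechanism, where a single 1x1 conv and a log-scale parameterization stand in for the real block's conv1d -> WN -> conv1d stack (names are illustrative):

    import torch
    import torch.nn as nn

    class AffineCouplingSketch(nn.Module):
        # Toy stand-in for CouplingBlock's transform net.
        def __init__(self, channels):
            super().__init__()
            self.net = nn.Conv1d(channels // 2, channels, 1)

        def forward(self, x, reverse=False):
            x0, x1 = x.chunk(2, dim=1)             # split channels
            t, log_s = self.net(x0).chunk(2, dim=1)
            if not reverse:
                y1 = torch.exp(log_s) * x1 + t     # affine transform
            else:
                y1 = (x1 - t) * torch.exp(-log_s)  # exact inverse
            return torch.cat([x0, y1], dim=1)

    block = AffineCouplingSketch(8)
    x = torch.randn(2, 8, 30)
    z = block(x)
    print(torch.allclose(block(z, reverse=True), x, atol=1e-5))  # True
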
@@ -311,6 +311,8 @@ class RelativePositionTransformer(nn.Module):
         https://arxiv.org/abs/1803.02155
 
     Args:
+        in_channels (int): number of channels of the input tensor.
+        out_channels (int): number of channels of the output tensor.
         hidden_channels (int): model hidden channels.
         hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
         num_heads (int): number of attention heads.
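
For context, the relative-position scheme cited above (Shaw et al., https://arxiv.org/abs/1803.02155) adds learned per-offset embeddings to the attention logits. A toy sketch of that one idea, with assumed shapes and clipping window, omitting the 1/sqrt(d) scaling and the module's other machinery:

    import torch

    B, H, T, d = 2, 4, 6, 16   # batch, heads, time steps, head dim
    max_rel = 3                # clip relative offsets to [-3, 3]
    q = torch.randn(B, H, T, d)
    k = torch.randn(B, H, T, d)
    rel_emb = torch.randn(2 * max_rel + 1, d)  # one embedding per clipped offset

    idx = torch.arange(T)
    rel = (idx[None, :] - idx[:, None]).clamp(-max_rel, max_rel) + max_rel  # [T, T]
    logits = q @ k.transpose(-1, -2)                                   # content term
    logits = logits + torch.einsum('bhtd,tsd->bhts', q, rel_emb[rel])  # position term
    attn = logits.softmax(dim=-1)                                      # [B, H, T, T]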