mirror of https://github.com/coqui-ai/TTS.git
glow-tts comments and refactoring
parent 7586fbc4de
commit d3b7284be4
@@ -7,6 +7,11 @@ from ..generic.normalization import LayerNorm
 
 
 class ConvLayerNorm(nn.Module):
+    """Residual Convolution with LayerNorm
+        x -> conv1d -> layer_norm -> dropout -> + -> o
+        |                                       |
+        ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
+    """
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                  num_layers, dropout_p):
         super().__init__()
@@ -22,16 +27,9 @@ class ConvLayerNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
 
-        self.conv_layers.append(
-            nn.Conv1d(in_channels,
-                      hidden_channels,
-                      kernel_size,
-                      padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
-
-        for _ in range(num_layers - 1):
+        for idx in range(num_layers - 1):
             self.conv_layers.append(
-                nn.Conv1d(hidden_channels,
+                nn.Conv1d(in_channels if idx == 0 else hidden_channels,
                           hidden_channels,
                           kernel_size,
                           padding=kernel_size // 2))
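The new docstring sketches the block's residual path: a stack of conv1d -> layer_norm -> dropout stages, with the original input added back at the end. A minimal standalone sketch of that pattern, assuming plain torch.nn building blocks (not the repo's own LayerNorm, projection, or mask handling):

import torch
from torch import nn

class ResidualConvLayerNormSketch(nn.Module):
    """Illustrative sketch of the conv1d -> layer_norm -> dropout -> + pattern."""
    def __init__(self, channels, kernel_size=3, num_layers=2, dropout_p=0.1):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
             for _ in range(num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(channels) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):  # x: B x C x T
        x_res = x
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = norm(x.transpose(1, 2)).transpose(1, 2)  # normalize over channels
            x = self.dropout(x)
        return x_res + x  # residual connection from the diagram

# y = ResidualConvLayerNormSketch(80)(torch.randn(2, 80, 50))  # B x C x T in, same shape out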
@@ -52,6 +50,17 @@ class ConvLayerNorm(nn.Module):
 
 
 class InvConvNear(nn.Module):
+    """Invertible Convolution with splitting.
+    Args:
+        channels (int): input and output channels.
+        num_splits (int): number of splits, also H and W of the 1x1 conv layer.
+        no_jacobian (bool): enable/disable jacobian computations.
+
+    Note:
+        Split the input into groups of size self.num_splits and
+        perform 1x1 convolution separately. Cast the 1x1 conv operation
+        to 2d by reshaping the input for efficiency.
+    """
     def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
         assert num_splits % 2 == 0
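The Note describes the core trick: group the C channels into num_splits groups and mix the groups with one invertible num_splits x num_splits matrix, applied as a 1x1 2d convolution. A rough sketch of that idea under illustrative shapes (the module itself uses a slightly different split/permute and its own weight initialization):

import torch
import torch.nn.functional as F

B, C, T, num_splits = 2, 8, 50, 4
x = torch.randn(B, C, T)

# orthonormal mixing matrix; flip one column so det > 0 and logdet stays finite
weight = torch.linalg.qr(torch.randn(num_splits, num_splits)).Q
if torch.det(weight) < 0:
    weight[:, 0] = -weight[:, 0]

x_2d = x.view(B, num_splits, C // num_splits, T)               # split channels into groups
z = F.conv2d(x_2d, weight.view(num_splits, num_splits, 1, 1))  # 1x1 conv mixes the groups
z = z.view(B, C, T)                                            # back to B x C x T

# every one of the (C // num_splits) * T positions contributes log|det W| to the Jacobian
logdet = torch.logdet(weight) * (C // num_splits) * T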
@@ -67,9 +76,10 @@ class InvConvNear(nn.Module):
         self.weight = nn.Parameter(w_init)
 
     def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
-        """Split the input into groups of size self.num_splits and
-        perform 1x1 convolution separately. Cast 1x1 conv operation
-        to 2d by reshaping the input for efficiency.
-        """
+        """
+        Shapes:
+            x: B x C x T
+            x_mask: B x 1 x T
+        """
         b, c, t = x.size()
 
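Hypothetical usage following the new Shapes note (values are illustrative and assume the layer behaves as documented above):

import torch

layer = InvConvNear(channels=80, num_splits=4)
x = torch.randn(2, 80, 120)      # B x C x T
x_mask = torch.ones(2, 1, 120)   # B x 1 x T

z, logdet = layer(x, x_mask)                # forward pass: mixed channels and log|det|
x_back, _ = layer(z, x_mask, reverse=True)  # reverse pass inverts the 1x1 convolution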
@@ -112,6 +122,25 @@ class InvConvNear(nn.Module):
 
 
 class CouplingBlock(nn.Module):
+    """Glow Affine Coupling block.
+    For details: https://arxiv.org/pdf/1811.00002.pdf
+
+        x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
+         '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -^
+
+    Args:
+        in_channels (int): number of input tensor channels.
+        hidden_channels (int): number of hidden channels.
+        kernel_size (int): WaveNet filter kernel size.
+        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
+        num_layers (int): number of WaveNet layers.
+        c_in_channels (int): number of conditioning input channels.
+        dropout_p (float): WaveNet dropout rate.
+        sigmoid_scale (bool): enable/disable sigmoid scaling for the output scale.
+
+    Note:
+        It does not use the conditional inputs differently from WaveGlow.
+    """
     def __init__(self,
                  in_channels,
                  hidden_channels,
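A hypothetical instantiation matching the documented arguments (values are illustrative; dropout_p and sigmoid_scale are assumed to keep their defaults, and no conditioning tensor g is passed):

import torch

block = CouplingBlock(in_channels=80,
                      hidden_channels=192,
                      kernel_size=5,
                      dilation_rate=1,
                      num_layers=4,
                      c_in_channels=0)
x = torch.randn(2, 80, 100)         # B x C x T
z, logdet = block(x)                # forward: x1 is affinely transformed, x0 passes through
x_back, _ = block(z, reverse=True)  # reverse: undoes the affine transform on x1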
@@ -130,21 +159,28 @@ class CouplingBlock(nn.Module):
         self.c_in_channels = c_in_channels
         self.dropout_p = dropout_p
         self.sigmoid_scale = sigmoid_scale
 
+        # input layer
         start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
         start = torch.nn.utils.weight_norm(start)
         self.start = start
+        # output layer
         # Initializing last layer to 0 makes the affine coupling layers
         # do nothing at first. This helps with training stability
         end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
         end.weight.data.zero_()
         end.bias.data.zero_()
         self.end = end
 
+        # coupling layers
         self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
                      num_layers, c_in_channels, dropout_p)
 
     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: B x C x T
+            x_mask: B x 1 x T
+            g: B x C x 1
+        """
         if x_mask is None:
             x_mask = 1
         x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
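The comment about zero-initialising the `end` layer can be checked in isolation: with zero weights and bias, t and s start at zero, so with sigmoid_scale disabled the coupling begins as an identity map, which is what stabilises early training. A small standalone check (tensor names are illustrative, not taken from the module):

import torch

x_1 = torch.randn(2, 40, 100)
t = torch.zeros_like(x_1)  # output of the zero-initialised `end` layer at step 0
s = torch.zeros_like(x_1)
z_1 = t + torch.exp(s) * x_1
assert torch.equal(z_1, x_1)  # the block starts as the identity transform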
@@ -154,17 +190,17 @@ class CouplingBlock(nn.Module):
         out = self.end(x)
 
         z_0 = x_0
-        m = out[:, :self.in_channels // 2, :]
-        logs = out[:, self.in_channels // 2:, :]
+        t = out[:, :self.in_channels // 2, :]
+        s = out[:, self.in_channels // 2:, :]
         if self.sigmoid_scale:
-            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+            s = torch.log(1e-6 + torch.sigmoid(s + 2))
 
         if reverse:
-            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+            z_1 = (x_1 - t) * torch.exp(-s) * x_mask
             logdet = None
         else:
-            z_1 = (m + torch.exp(logs) * x_1) * x_mask
-            logdet = torch.sum(logs * x_mask, [1, 2])
+            z_1 = (t + torch.exp(s) * x_1) * x_mask
+            logdet = torch.sum(s * x_mask, [1, 2])
 
         z = torch.cat([z_0, z_1], 1)
         return z, logdet
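The rename from m/logs to t/s does not change the math: the affine coupling z1 = t + exp(s) * x1 is inverted exactly by x1 = (z1 - t) * exp(-s), and its per-sample log-determinant is just the sum of s. A quick numeric check with illustrative tensors and no masking:

import torch

x_1 = torch.randn(2, 40, 100)
t = torch.randn_like(x_1)
s = 0.1 * torch.randn_like(x_1)

z_1 = t + torch.exp(s) * x_1        # forward branch (reverse=False)
x_back = (z_1 - t) * torch.exp(-s)  # reverse branch
assert torch.allclose(x_back, x_1, atol=1e-5)

logdet = torch.sum(s, [1, 2])       # log|det| of the Jacobian per batch item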