mirror of https://github.com/coqui-ai/TTS.git
glow-tts comments and refactoring
This commit is contained in:
parent 7586fbc4de
commit d3b7284be4
@@ -7,6 +7,11 @@ from ..generic.normalization import LayerNorm
 
 
 class ConvLayerNorm(nn.Module):
+    """Residual Convolution with LayerNorm
+
+    x -> conv1d -> layer_norm -> dropout -> + -> o
+    |_______________________________________^
+    """
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                  num_layers, dropout_p):
         super().__init__()
@@ -22,16 +27,9 @@ class ConvLayerNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(in_channels,
-                      hidden_channels,
-                      kernel_size,
-                      padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
-        for _ in range(num_layers - 1):
-            self.conv_layers.append(
-                nn.Conv1d(hidden_channels,
-                          hidden_channels,
-                          kernel_size,
-                          padding=kernel_size // 2))
-            self.norm_layers.append(LayerNorm(hidden_channels))
+        for idx in range(num_layers - 1):
+            self.conv_layers.append(
+                nn.Conv1d(in_channels if idx == 0 else hidden_channels,
+                          hidden_channels,
+                          kernel_size,
+                          padding=kernel_size // 2))
+            self.norm_layers.append(LayerNorm(hidden_channels))
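With the two loops merged, the `in_channels if idx == 0 else hidden_channels` branch is what keeps the first convolution's input width distinct from the rest. A standalone sketch (hypothetical channel sizes, not taken from this diff) of the layer stack the new loop builds:

import torch.nn as nn

# hypothetical sizes for illustration only
in_channels, hidden_channels, kernel_size, num_layers = 80, 192, 5, 4

conv_layers = nn.ModuleList()
for idx in range(num_layers - 1):
    # first layer maps in_channels -> hidden_channels, the rest stay at hidden_channels
    conv_layers.append(
        nn.Conv1d(in_channels if idx == 0 else hidden_channels,
                  hidden_channels,
                  kernel_size,
                  padding=kernel_size // 2))

print([(c.in_channels, c.out_channels) for c in conv_layers])
# [(80, 192), (192, 192), (192, 192)]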
@@ -52,6 +50,17 @@ class ConvLayerNorm(nn.Module):
 
 
 class InvConvNear(nn.Module):
+    """Invertible Convolution with splitting.
+
+    Args:
+        channels (int): input and output channels.
+        num_splits (int): number of splits, also H and W of conv layer.
+        no_jacobian (bool): enable/disable jacobian computations.
+
+    Note:
+        Split the input into groups of size self.num_splits and
+        perform 1x1 convolution separately. Cast 1x1 conv operation
+        to 2d by reshaping the input for efficiency.
+    """
     def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
         assert num_splits % 2 == 0
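The new docstring's Note describes the core trick: group the channels and apply one small invertible matrix as a 1x1 convolution, so inverting the layer only needs the matrix inverse. A simplified sketch of that idea (illustrative only; the class itself uses a different reshape and conv2d):

import torch

num_splits = 4
b, c, t = 2, 8, 10
x = torch.randn(b, c, t)

# random orthogonal (hence invertible) weight via QR, as in the layer's init
w = torch.linalg.qr(torch.randn(num_splits, num_splits))[0]

# group channels: B x C x T -> B x num_splits x (C // num_splits) x T
x_grouped = x.view(b, num_splits, c // num_splits, t)
y = torch.einsum("ij,bjct->bict", w, x_grouped)

# reverse pass: apply the inverse matrix and recover the input exactly
x_rec = torch.einsum("ij,bjct->bict", torch.inverse(w), y)
print(torch.allclose(x_rec, x_grouped, atol=1e-5))  # True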
@@ -67,9 +76,10 @@ class InvConvNear(nn.Module):
         self.weight = nn.Parameter(w_init)
 
     def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
-        """Split the input into groups of size self.num_splits and
-        perform 1x1 convolution separately. Cast 1x1 conv operation
-        to 2d by reshaping the input for efficiency.
-        """
+        """
+        Shapes:
+            x: B x C x T
+            x_mask: B x 1 x T
+        """
         b, c, t = x.size()
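A usage sketch for the documented shapes; the import path is assumed from this repo's layout, since the diff does not show the file path:

import torch
from TTS.tts.layers.glow_tts.glow import InvConvNear  # path assumed

layer = InvConvNear(channels=8, num_splits=4)
x = torch.randn(2, 8, 100)      # B x C x T
x_mask = torch.ones(2, 1, 100)  # B x 1 x T

z, logdet = layer(x, x_mask)                 # forward pass
x_back, _ = layer(z, x_mask, reverse=True)   # inverse pass (logdet is None)
print(torch.allclose(x_back, x, atol=1e-4))  # True, up to float error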
@@ -112,6 +122,25 @@ class InvConvNear(nn.Module):
 
 
 class CouplingBlock(nn.Module):
+    """Glow Affine Coupling block.
+    For details, see https://arxiv.org/pdf/1811.00002.pdf
+
+    x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
+    '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
+
+    Args:
+        in_channels (int): number of input tensor channels.
+        hidden_channels (int): number of hidden channels.
+        kernel_size (int): WaveNet filter kernel size.
+        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
+        num_layers (int): number of WaveNet layers.
+        c_in_channels (int): number of conditioning input channels.
+        dropout_p (float): WaveNet dropout rate.
+        sigmoid_scale (bool): enable/disable sigmoid scaling for the output scale.
+
+    Note:
+        It uses the conditional inputs the same way as WaveGlow.
+    """
     def __init__(self,
                  in_channels,
                  hidden_channels,
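The diagram describes a standard affine coupling: x0 passes through untouched and parameterizes a shift t and log-scale s applied to x1, which keeps the block invertible in closed form. A tiny numeric check of the forward/inverse pair and its log-determinant:

import torch

x_0, x_1 = torch.randn(4), torch.randn(4)  # the two channel halves
t, s = torch.randn(4), torch.randn(4)      # shift and log-scale from the network

z_1 = t + torch.exp(s) * x_1            # forward: affine transform of x_1
x_1_rec = (z_1 - t) * torch.exp(-s)     # inverse: solve for x_1 given t, s
print(torch.allclose(x_1_rec, x_1, atol=1e-6))  # True

logdet = s.sum()  # Jacobian of the forward map is diag(exp(s)), so log|det J| = sum(s)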
@@ -130,21 +159,28 @@ class CouplingBlock(nn.Module):
         self.c_in_channels = c_in_channels
         self.dropout_p = dropout_p
         self.sigmoid_scale = sigmoid_scale
+        # input layer
         start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
         start = torch.nn.utils.weight_norm(start)
         self.start = start
+        # output layer
         # Initializing last layer to 0 makes the affine coupling layers
         # do nothing at first. This helps with training stability
         end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
         end.weight.data.zero_()
         end.bias.data.zero_()
         self.end = end
+        # coupling layers
         self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
                      num_layers, c_in_channels, dropout_p)
 
     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: B x C x T
+            x_mask: B x 1 x T
+            g: B x C x 1
+        """
         if x_mask is None:
             x_mask = 1
         x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
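Because `end` starts at zero, the predicted t and s are zero everywhere at the first step, so the coupling block begins as the identity map. A quick check of that claim (hypothetical channel sizes):

import torch

end = torch.nn.Conv1d(16, 32, 1)
end.weight.data.zero_()
end.bias.data.zero_()

h = torch.randn(2, 16, 50)       # stand-in for the wavenet output
out = end(h)
t, s = out[:, :16], out[:, 16:]  # both all zeros at initialization
x_1 = torch.randn(2, 16, 50)
z_1 = t + torch.exp(s) * x_1     # identity: 0 + exp(0) * x_1
print(torch.equal(z_1, x_1))     # True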
@@ -154,17 +190,17 @@ class CouplingBlock(nn.Module):
         out = self.end(x)
 
         z_0 = x_0
-        m = out[:, :self.in_channels // 2, :]
-        logs = out[:, self.in_channels // 2:, :]
+        t = out[:, :self.in_channels // 2, :]
+        s = out[:, self.in_channels // 2:, :]
         if self.sigmoid_scale:
-            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+            s = torch.log(1e-6 + torch.sigmoid(s + 2))
 
         if reverse:
-            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+            z_1 = (x_1 - t) * torch.exp(-s) * x_mask
             logdet = None
         else:
-            z_1 = (m + torch.exp(logs) * x_1) * x_mask
-            logdet = torch.sum(logs * x_mask, [1, 2])
+            z_1 = (t + torch.exp(s) * x_1) * x_mask
+            logdet = torch.sum(s * x_mask, [1, 2])
 
         z = torch.cat([z_0, z_1], 1)
         return z, logdet
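The m/logs pair is renamed to t/s to match the notation in the new class docstring. With `sigmoid_scale` enabled, the log-scale is squashed before exponentiation so the effective scale exp(s) stays within roughly (1e-6, 1), which guards against exploding scales early in training. A quick look at the bound:

import torch

s_raw = torch.linspace(-10, 10, 5)
s = torch.log(1e-6 + torch.sigmoid(s_raw + 2))
print(torch.exp(s))  # values in (1e-6, 1 + 1e-6), saturating near 1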