diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index 2a2ff189..f130e26a 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -7,6 +7,11 @@ from ..generic.normalization import LayerNorm class ConvLayerNorm(nn.Module): + """Residual Convolution with LayerNorm + x -> conv1d -> layer_norm -> dropout -> + -> o + | | + ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯ + """ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): super().__init__() @@ -22,16 +27,9 @@ class ConvLayerNorm(nn.Module): self.conv_layers = nn.ModuleList() self.norm_layers = nn.ModuleList() - self.conv_layers.append( - nn.Conv1d(in_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - - for _ in range(num_layers - 1): + for idx in range(num_layers - 1): self.conv_layers.append( - nn.Conv1d(hidden_channels, + nn.Conv1d(in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) @@ -52,6 +50,17 @@ class ConvLayerNorm(nn.Module): class InvConvNear(nn.Module): + """Invertible Convolution with splitting. + Args: + channels (int): input and output channels. + num_splits (int): number of splits, also H and W of conv layer. + no_jacobian (bool): enable/disable jacobian computations. + + Note: + Split the input into groups of size self.num_splits and + perform 1x1 convolution separately. Cast 1x1 conv operation + to 2d by reshaping the input for efficiency. + """ def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument super().__init__() assert num_splits % 2 == 0 @@ -67,9 +76,10 @@ class InvConvNear(nn.Module): self.weight = nn.Parameter(w_init) def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument - """Split the input into groups of size self.num_splits and - perform 1x1 convolution separately. 
Cast 1x1 conv operation - to 2d by reshaping the input for efficiency. + """ + Shapes: + x: B x C x T + x_mask: B x 1 x T """ b, c, t = x.size() @@ -112,6 +122,25 @@ class InvConvNear(nn.Module): class CouplingBlock(nn.Module): + """Glow Affine Coupling block. + For details, see https://arxiv.org/pdf/1811.00002.pdf + + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(x0, exp(s)*x1 + t) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + + Note: + It does not use conditional inputs differently from WaveGlow. + """ def __init__(self, in_channels, hidden_channels, @@ -130,21 +159,28 @@ class CouplingBlock(nn.Module): self.c_in_channels = c_in_channels self.dropout_p = dropout_p self.sigmoid_scale = sigmoid_scale - + # input layer start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) start = torch.nn.utils.weight_norm(start) self.start = start + # output layer # Initializing last layer to 0 makes the affine coupling layers # do nothing at first. 
This helps with training stability end = torch.nn.Conv1d(hidden_channels, in_channels, 1) end.weight.data.zero_() end.bias.data.zero_() self.end = end - + # coupling layers self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + x: B x C x T + x_mask: B x 1 x T + g: B x C x 1 + """ if x_mask is None: x_mask = 1 x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] @@ -154,17 +190,17 @@ class CouplingBlock(nn.Module): out = self.end(x) z_0 = x_0 - m = out[:, :self.in_channels // 2, :] - logs = out[:, self.in_channels // 2:, :] + t = out[:, :self.in_channels // 2, :] + s = out[:, self.in_channels // 2:, :] if self.sigmoid_scale: - logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) + s = torch.log(1e-6 + torch.sigmoid(s + 2)) if reverse: - z_1 = (x_1 - m) * torch.exp(-logs) * x_mask + z_1 = (x_1 - t) * torch.exp(-s) * x_mask logdet = None else: - z_1 = (m + torch.exp(logs) * x_1) * x_mask - logdet = torch.sum(logs * x_mask, [1, 2]) + z_1 = (t + torch.exp(s) * x_1) * x_mask + logdet = torch.sum(s * x_mask, [1, 2]) z = torch.cat([z_0, z_1], 1) return z, logdet