update glow-tts layers and add some comments

This commit is contained in:
erogol 2021-01-05 14:25:12 +01:00
parent 29cf933831
commit 29f4329d7f
3 changed files with 84 additions and 27 deletions

View File

@@ -60,9 +60,11 @@ class ResidualConvBNBlock(nn.Module):
         self.res_blocks.append(block)

     def forward(self, x, x_mask=None):
-        o = x
+        o = x * x_mask
         for block in self.res_blocks:
             res = o
-            o = block(o * x_mask if x_mask is not None else o)
+            o = block(o)
             o = o + res
+        if x_mask is not None:
+            o = o * x_mask
         return o
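
For readers skimming the hunk: the refactor masks the input once up front and re-applies the mask after the residual stack, instead of masking inside every block call. A minimal standalone sketch of the new control flow, with hypothetical stand-in blocks (note the committed line multiplies by x_mask unconditionally, so callers must pass a mask; the sketch guards the None default so it runs on its own):

    import torch
    from torch import nn

    # Stand-in residual blocks; the real module builds Conv1d + BatchNorm blocks.
    res_blocks = nn.ModuleList([nn.Conv1d(4, 4, 3, padding=1) for _ in range(2)])

    def forward(x, x_mask=None):
        o = x if x_mask is None else x * x_mask  # mask once up front
        for block in res_blocks:
            res = o
            o = block(o)          # blocks no longer re-mask internally
            o = o + res           # residual connection
        if x_mask is not None:
            o = o * x_mask        # zero out padded time steps again
        return o

    x = torch.randn(1, 4, 10)      # (B, C, T)
    x_mask = torch.ones(1, 1, 10)  # (B, 1, T)
    print(forward(x, x_mask).shape)  # torch.Size([1, 4, 10])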

View File

@@ -6,6 +6,13 @@ from TTS.tts.layers.generic.normalization import ActNorm

 def squeeze(x, x_mask=None, num_sqz=2):
-    """GlowTTS squeeze operation"""
+    """GlowTTS squeeze operation
+    Increase the number of channels and reduce the number of time steps
+    by the same factor.
+
+    Note:
+        each 's' is an n-dimensional vector.
+        [s1, s2, s3, s4, s5, s6] --> [[s1, s3, s5], [s2, s4, s6]]
+    """
     b, c, t = x.size()
     t = (t // num_sqz) * num_sqz
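
The reshape behind that docstring, as a small numeric sketch (the view/permute layout is assumed from the Glow-TTS implementation; masking and the truncation to a multiple of num_sqz are omitted):

    import torch

    b, c, t, num_sqz = 1, 1, 6, 2
    x = torch.arange(1., 7.).view(b, c, t)          # [s1, s2, s3, s4, s5, s6]
    x_sqz = x.view(b, c, t // num_sqz, num_sqz)     # pair up neighbouring steps
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous()  # (b, num_sqz, c, t // num_sqz)
    x_sqz = x_sqz.view(b, c * num_sqz, t // num_sqz)
    print(x_sqz)  # tensor([[[1., 3., 5.], [2., 4., 6.]]])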
@@ -23,6 +30,11 @@ def squeeze(x, x_mask=None, num_sqz=2):

 def unsqueeze(x, x_mask=None, num_sqz=2):
-    """GlowTTS unsqueeze operation"""
+    """GlowTTS unsqueeze operation
+
+    Note:
+        each 's' is an n-dimensional vector.
+        [[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]
+    """
     b, c, t = x.size()
     x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
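
Continuing the squeeze sketch above, unsqueeze inverts it and restores the original ordering (again a hedged sketch of the reshape only):

    # x_sqz from the previous sketch has shape (1, 2, 3)
    c2, t2 = 2, 3  # channels / time steps after squeezing
    x_unsqz = x_sqz.view(b, num_sqz, c2 // num_sqz, t2)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c2 // num_sqz, t2 * num_sqz)
    print(x_unsqz)  # tensor([[[1., 2., 3., 4., 5., 6.]]])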
@@ -40,7 +52,19 @@ def unsqueeze(x, x_mask=None, num_sqz=2):

 class Decoder(nn.Module):
-    """Stack of Glow Modules"""
+    """Stack of Glow Decoder Modules.
+
+    Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
+    Args:
+        in_channels (int): channels of the input tensor.
+        hidden_channels (int): hidden decoder channels.
+        kernel_size (int): coupling block kernel size (WaveNet filter kernel size).
+        dilation_rate (int): rate to increase dilation per layer in each decoder block.
+        num_flow_blocks (int): number of decoder blocks.
+        num_coupling_layers (int): number of coupling layers (number of WaveNet layers).
+        dropout_p (float): WaveNet dropout rate.
+        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layer.
+    """

     def __init__(self,
                  in_channels,
                  hidden_channels,
@@ -50,7 +74,7 @@ class Decoder(nn.Module):
                  num_coupling_layers,
                  dropout_p=0.,
                  num_splits=4,
-                 num_sqz=2,
+                 num_squeeze=2,
                  sigmoid_scale=False,
                  c_in_channels=0):
         super().__init__()
@@ -63,18 +87,18 @@ class Decoder(nn.Module):
         self.num_coupling_layers = num_coupling_layers
         self.dropout_p = dropout_p
         self.num_splits = num_splits
-        self.num_sqz = num_sqz
+        self.num_squeeze = num_squeeze
         self.sigmoid_scale = sigmoid_scale
         self.c_in_channels = c_in_channels

         self.flows = nn.ModuleList()
         for _ in range(num_flow_blocks):
-            self.flows.append(ActNorm(channels=in_channels * num_sqz))
+            self.flows.append(ActNorm(channels=in_channels * num_squeeze))
             self.flows.append(
-                InvConvNear(channels=in_channels * num_sqz,
+                InvConvNear(channels=in_channels * num_squeeze,
                             num_splits=num_splits))
             self.flows.append(
-                CouplingBlock(in_channels * num_sqz,
+                CouplingBlock(in_channels * num_squeeze,
                               hidden_channels,
                               kernel_size=kernel_size,
                               dilation_rate=dilation_rate,
@@ -91,16 +115,16 @@ class Decoder(nn.Module):
             flows = reversed(self.flows)
             logdet_tot = None

-        if self.num_sqz > 1:
-            x, x_mask = squeeze(x, x_mask, self.num_sqz)
+        if self.num_squeeze > 1:
+            x, x_mask = squeeze(x, x_mask, self.num_squeeze)
         for f in flows:
             if not reverse:
                 x, logdet = f(x, x_mask, g=g, reverse=reverse)
                 logdet_tot += logdet
             else:
                 x, logdet = f(x, x_mask, g=g, reverse=reverse)
-        if self.num_sqz > 1:
-            x, x_mask = unsqueeze(x, x_mask, self.num_sqz)
+        if self.num_squeeze > 1:
+            x, x_mask = unsqueeze(x, x_mask, self.num_squeeze)
         return x, logdet_tot

     def store_inverse(self):
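
A hedged usage sketch with the renamed keyword, assuming the forward signature shown in the hunk above; the argument values are illustrative defaults, not taken from this commit or any shipped config:

    import torch
    from TTS.tts.layers.glow_tts.decoder import Decoder

    decoder = Decoder(in_channels=80,        # e.g. mel-spectrogram channels
                      hidden_channels=192,
                      kernel_size=5,
                      dilation_rate=1,
                      num_flow_blocks=12,
                      num_coupling_layers=4,
                      num_squeeze=2)         # renamed from num_sqz

    x = torch.randn(2, 80, 100)              # (B, C, T)
    x_mask = torch.ones(2, 1, 100)           # (B, 1, T)
    z, logdet = decoder(x, x_mask)           # forward (training) direction
    y, _ = decoder(z, x_mask, reverse=True)  # inverse (inference) direction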

View File

@@ -2,7 +2,7 @@ import math

 import torch
 from torch import nn

-from TTS.tts.layers.glow_tts.transformer import Transformer
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
@@ -12,16 +12,20 @@ from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock

 class Encoder(nn.Module):
-    """Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding
-    as in the original paper or GatedConvBlock as a faster alternative.
+    """Glow-TTS encoder module.
+
+    embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
+                                                 |
+                                                 |-> proj_var
+                                                 |
+                                                 |-> concat -> duration_predictor
+                                                        ↑
+                                                  speaker_embed

     Args:
         num_chars (int): number of characters.
         out_channels (int): number of output channels.
         hidden_channels (int): encoder's embedding size.
         hidden_channels_ffn (int): transformer's feed-forward channels.
         num_heads (int): number of attention heads in the transformer.
         num_layers (int): number of layers in the transformer encoder stack.
         kernel_size (int): kernel size for conv layers and the duration predictor.
         dropout_p (float): dropout rate for any dropout layer.
         mean_only (bool): if True, output only mean values and use constant std.
@@ -30,19 +34,49 @@ class Encoder(nn.Module):

     Shapes:
         - input: (B, T, C)

+    Notes:
+        suggested encoder params...
+
+        for encoder_type == 'rel_pos_transformer':
+            encoder_params = {
+                'kernel_size': 3,
+                'dropout_p': 0.1,
+                'num_layers': 6,
+                'num_heads': 2,
+                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
+                'input_length': None
+            }
+
+        for encoder_type == 'gated_conv':
+            encoder_params = {
+                'kernel_size': 5,
+                'dropout_p': 0.1,
+                'num_layers': 9,
+            }
+
+        for encoder_type == 'residual_conv_bn':
+            encoder_params = {
+                "kernel_size": 4,
+                "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 13
+            }
+
+        for encoder_type == 'time_depth_separable':
+            encoder_params = {
+                "kernel_size": 5,
+                'num_layers': 9,
+            }
     """

     def __init__(self,
                  num_chars,
                  out_channels,
                  hidden_channels,
-                 hidden_channels_ffn,
                  hidden_channels_dp,
-                 num_heads,
-                 num_layers,
-                 dropout_p,
-                 rel_attn_window_size=None,
-                 input_length=None,
+                 encoder_type,
+                 encoder_params,
+                 dropout_p_dp=0.1,
                  mean_only=False,
                  use_prenet=True,
                  c_in_channels=0):
@@ -51,11 +85,8 @@ class Encoder(nn.Module):
         self.num_chars = num_chars
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
-        self.hidden_channels_ffn = hidden_channels_ffn
         self.hidden_channels_dp = hidden_channels_dp
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.dropout_p = dropout_p
+        self.dropout_p_dp = dropout_p_dp
         self.mean_only = mean_only
         self.use_prenet = use_prenet
         self.c_in_channels = c_in_channels
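
Finally, a hedged construction sketch for the new encoder_type/encoder_params interface, using the suggested rel_pos_transformer parameters from the docstring above; num_chars and the channel sizes are illustrative, not taken from this commit:

    from TTS.tts.layers.glow_tts.encoder import Encoder

    encoder = Encoder(num_chars=120,
                      out_channels=80,
                      hidden_channels=192,
                      hidden_channels_dp=256,
                      encoder_type='rel_pos_transformer',
                      encoder_params={
                          'kernel_size': 3,
                          'dropout_p': 0.1,
                          'num_layers': 6,
                          'num_heads': 2,
                          'hidden_channels_ffn': 768,
                          'input_length': None,
                      })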