mirror of https://github.com/coqui-ai/TTS.git

commit 29f4329d7f (parent 29cf933831)
update glow-tts layers and add some comments
In TTS/tts/layers/generic/res_conv_bn.py (file path inferred from the hunk context):

@@ -60,9 +60,11 @@ class ResidualConvBNBlock(nn.Module):
             self.res_blocks.append(block)
 
     def forward(self, x, x_mask=None):
-        o = x
+        o = x * x_mask
         for block in self.res_blocks:
             res = o
-            o = block(o * x_mask if x_mask is not None else o)
+            o = block(o)
             o = o + res
+        if x_mask is not None:
+            o = o * x_mask
         return o
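This hunk moves masking out of the per-block call: the input is masked once up front and the output is re-masked once after the residual loop, instead of re-masking inside every block iteration. A minimal sketch of the resulting method (not the repo file itself); note that the committed first line multiplies by x_mask unconditionally despite the x_mask=None default, so the guard on the first line below is our addition for unmasked calls:

    def forward(self, x, x_mask=None):
        # mask padded frames once on the way in (guard added in this sketch;
        # the committed code does `o = x * x_mask` unconditionally)
        o = x if x_mask is None else x * x_mask
        for block in self.res_blocks:
            res = o
            o = block(o)     # blocks no longer re-apply the mask internally
            o = o + res      # residual connection
        if x_mask is not None:
            o = o * x_mask   # re-mask once on the way out
        return o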
In TTS/tts/layers/glow_tts/decoder.py (file path inferred from the imports shown):

@@ -6,6 +6,13 @@ from TTS.tts.layers.generic.normalization import ActNorm
 
 
 def squeeze(x, x_mask=None, num_sqz=2):
+    """GlowTTS squeeze operation.
+    Increase the number of channels and reduce the number of time steps
+    by the same factor.
+
+    Note:
+        each 's' is an n-dimensional vector.
+        [s1, s2, s3, s4, s5, s6] --> [[s1, s3, s5], [s2, s4, s6]]"""
     b, c, t = x.size()
 
     t = (t // num_sqz) * num_sqz
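The hunk shows only the head of squeeze; as a reading aid, here is a self-contained sketch of the whole operation, consistent with the new docstring. The mask handling (keeping every num_sqz-th frame) is an assumption based on the (B, C, T) shapes, not part of the diff:

    import torch

    def squeeze(x, x_mask=None, num_sqz=2):
        """(B, C, T) -> (B, C * num_sqz, T // num_sqz)."""
        b, c, t = x.size()
        t = (t // num_sqz) * num_sqz   # drop trailing frames that do not fill a group
        x = x[:, :, :t]
        # fold every group of num_sqz consecutive frames into the channel axis
        x_sqz = x.view(b, c, t // num_sqz, num_sqz)
        x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
        if x_mask is not None:
            # a squeezed frame is valid only if its whole group was valid
            x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
        else:
            x_mask = torch.ones(b, 1, t // num_sqz, device=x.device, dtype=x.dtype)
        return x_sqz * x_mask, x_mask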
@@ -23,6 +30,11 @@ def squeeze(x, x_mask=None, num_sqz=2):
 
 
 def unsqueeze(x, x_mask=None, num_sqz=2):
+    """GlowTTS unsqueeze operation (inverse of squeeze).
+
+    Note:
+        each 's' is an n-dimensional vector.
+        [[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]"""
     b, c, t = x.size()
 
     x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
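The matching sketch for unsqueeze, again filling in the body the hunk elides (the mask expansion is our assumption), plus a round-trip check showing it inverts the squeeze sketch above when T is divisible by num_sqz:

    def unsqueeze(x, x_mask=None, num_sqz=2):
        """(B, C, T) -> (B, C // num_sqz, T * num_sqz); inverse of squeeze()."""
        b, c, t = x.size()
        x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
        x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz)
        if x_mask is not None:
            # expand each mask frame back over the frames it was squeezed from
            x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz)
        else:
            x_mask = torch.ones(b, 1, t * num_sqz, device=x.device, dtype=x.dtype)
        return x_unsqz * x_mask, x_mask

    x = torch.arange(12.).view(1, 2, 6)   # [s1..s6] with 2-dimensional vectors
    y, y_mask = squeeze(x)                # (1, 4, 3): [[s1, s3, s5], [s2, s4, s6]]
    z, _ = unsqueeze(y, y_mask)           # (1, 2, 6): back to [s1..s6]
    assert torch.equal(x, z)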
@@ -40,7 +52,19 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
 
 
 class Decoder(nn.Module):
-    """Stack of Glow Modules"""
+    """Stack of Glow Decoder Modules.
+    Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
+
+    Args:
+        in_channels (int): channels of the input tensor.
+        hidden_channels (int): hidden decoder channels.
+        kernel_size (int): coupling block kernel size (WaveNet filter kernel size).
+        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
+        num_flow_blocks (int): number of decoder blocks.
+        num_coupling_layers (int): number of coupling layers (number of WaveNet layers).
+        dropout_p (float): WaveNet dropout rate.
+        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layer.
+    """
     def __init__(self,
                  in_channels,
                  hidden_channels,
@@ -50,7 +74,7 @@ class Decoder(nn.Module):
                  num_coupling_layers,
                  dropout_p=0.,
                  num_splits=4,
-                 num_sqz=2,
+                 num_squeeze=2,
                  sigmoid_scale=False,
                  c_in_channels=0):
         super().__init__()
@@ -63,18 +87,18 @@ class Decoder(nn.Module):
         self.num_coupling_layers = num_coupling_layers
         self.dropout_p = dropout_p
         self.num_splits = num_splits
-        self.num_sqz = num_sqz
+        self.num_squeeze = num_squeeze
         self.sigmoid_scale = sigmoid_scale
         self.c_in_channels = c_in_channels
 
         self.flows = nn.ModuleList()
         for _ in range(num_flow_blocks):
-            self.flows.append(ActNorm(channels=in_channels * num_sqz))
+            self.flows.append(ActNorm(channels=in_channels * num_squeeze))
             self.flows.append(
-                InvConvNear(channels=in_channels * num_sqz,
+                InvConvNear(channels=in_channels * num_squeeze,
                             num_splits=num_splits))
             self.flows.append(
-                CouplingBlock(in_channels * num_sqz,
+                CouplingBlock(in_channels * num_squeeze,
                               hidden_channels,
                               kernel_size=kernel_size,
                               dilation_rate=dilation_rate,
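Why every flow layer above is sized in_channels * num_squeeze: the decoder squeezes its input before the flows run, which multiplies channels and divides time by the same factor. A concrete shape check, using the squeeze sketch from earlier; the 80-channel mel input is an invented example value:

    import torch

    x = torch.randn(1, 80, 100)              # (B, C, T) batch of mel frames
    x_mask = torch.ones(1, 1, 100)
    x_sqz, mask_sqz = squeeze(x, x_mask, 2)  # -> (1, 160, 50)
    # so ActNorm / InvConvNear / CouplingBlock all operate on 160 channels
    assert x_sqz.shape == (1, 160, 50)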
@@ -91,16 +115,16 @@ class Decoder(nn.Module):
             flows = reversed(self.flows)
             logdet_tot = None
 
-        if self.num_sqz > 1:
-            x, x_mask = squeeze(x, x_mask, self.num_sqz)
+        if self.num_squeeze > 1:
+            x, x_mask = squeeze(x, x_mask, self.num_squeeze)
         for f in flows:
             if not reverse:
                 x, logdet = f(x, x_mask, g=g, reverse=reverse)
                 logdet_tot += logdet
             else:
                 x, logdet = f(x, x_mask, g=g, reverse=reverse)
-        if self.num_sqz > 1:
-            x, x_mask = unsqueeze(x, x_mask, self.num_sqz)
+        if self.num_squeeze > 1:
+            x, x_mask = unsqueeze(x, x_mask, self.num_squeeze)
         return x, logdet_tot
 
     def store_inverse(self):
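The forward hunk runs the same module list in both directions: with reverse=False it maps frames to latents and accumulates the per-layer log-determinants the flow likelihood needs, while with reverse=True it iterates the flows backwards and leaves logdet_tot as None. A hypothetical usage sketch; the import path is inferred, all constructor values are invented, and the forward signature (x, x_mask, g=None, reverse=False) is assumed from the calls shown above:

    import torch
    from TTS.tts.layers.glow_tts.decoder import Decoder  # path inferred

    decoder = Decoder(in_channels=80, hidden_channels=192, kernel_size=5,
                      dilation_rate=1, num_flow_blocks=12, num_coupling_layers=4)
    y = torch.randn(2, 80, 100)              # mel frames
    y_mask = torch.ones(2, 1, 100)

    # training direction: mel -> latent, logdet feeds the negative log-likelihood
    z, logdet = decoder(y, y_mask, reverse=False)

    # inference direction: latent -> mel; logdet_tot stays None in reverse mode
    y_hat, _ = decoder(z, y_mask, reverse=True)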
In TTS/tts/layers/glow_tts/encoder.py (file path inferred from the imports shown):

@@ -2,7 +2,7 @@ import math
 import torch
 from torch import nn
 
-from TTS.tts.layers.glow_tts.transformer import Transformer
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
@@ -12,16 +12,20 @@ from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
 
 
 class Encoder(nn.Module):
-    """Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding
-    as in the original paper or GatedConvBlock as a faster alternative.
+    """Glow-TTS encoder module.
+
+    embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
+                                                         |
+                                                         |-> proj_var
+                                                         |
+                                                         |-> concat -> duration_predictor
+                                                                ↑
+                                                          speaker_embed
     Args:
         num_chars (int): number of characters.
         out_channels (int): number of output channels.
         hidden_channels (int): encoder's embedding size.
         hidden_channels_ffn (int): transformer's feed-forward channels.
-        num_head (int): number of attention heads in transformer.
-        num_layers (int): number of transformer encoder stack.
         kernel_size (int): kernel size for conv layers and duration predictor.
         dropout_p (float): dropout rate for any dropout layer.
         mean_only (bool): if True, output only mean values and use constant std.
@@ -30,19 +34,49 @@ class Encoder(nn.Module):
     Shapes:
         - input: (B, T, C)
 
+    Notes:
+        suggested encoder params...
+
+        for encoder_type == 'rel_pos_transformer'
+            encoder_params={
+                'kernel_size': 3,
+                'dropout_p': 0.1,
+                'num_layers': 6,
+                'num_heads': 2,
+                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
+                'input_length': None
+            }
+
+        for encoder_type == 'gated_conv'
+            encoder_params={
+                'kernel_size': 5,
+                'dropout_p': 0.1,
+                'num_layers': 9,
+            }
+
+        for encoder_type == 'residual_conv_bn'
+            encoder_params={
+                'kernel_size': 4,
+                'dilations': [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
+                'num_conv_blocks': 2,
+                'num_res_blocks': 13
+            }
+
+        for encoder_type == 'time_depth_separable'
+            encoder_params={
+                'kernel_size': 5,
+                'num_layers': 9,
+            }
+
     """
     def __init__(self,
                  num_chars,
                  out_channels,
                  hidden_channels,
-                 hidden_channels_ffn,
                  hidden_channels_dp,
                  encoder_type,
-                 num_heads,
-                 num_layers,
-                 dropout_p,
-                 rel_attn_window_size=None,
-                 input_length=None,
+                 encoder_params,
+                 dropout_p_dp=0.1,
                  mean_only=False,
                  use_prenet=True,
                  c_in_channels=0):
@@ -51,11 +85,8 @@ class Encoder(nn.Module):
         self.num_chars = num_chars
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
-        self.hidden_channels_ffn = hidden_channels_ffn
         self.hidden_channels_dp = hidden_channels_dp
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.dropout_p = dropout_p
+        self.dropout_p_dp = dropout_p_dp
         self.mean_only = mean_only
         self.use_prenet = use_prenet
         self.c_in_channels = c_in_channels
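With this refactor, per-module hyperparameters travel in the encoder_params dict instead of flat constructor arguments, so adding a new encoder type does not grow the Encoder signature. A hypothetical instantiation following the rel_pos_transformer suggestion from the docstring; the import path is inferred and every value is illustrative:

    from TTS.tts.layers.glow_tts.encoder import Encoder  # path inferred

    encoder = Encoder(num_chars=120,
                      out_channels=80,
                      hidden_channels=192,
                      hidden_channels_dp=256,
                      encoder_type='rel_pos_transformer',
                      encoder_params={
                          'kernel_size': 3,
                          'dropout_p': 0.1,
                          'num_layers': 6,
                          'num_heads': 2,
                          'hidden_channels_ffn': 768,  # 4x hidden_channels
                          'input_length': None,
                      },
                      dropout_p_dp=0.1,
                      mean_only=True)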