mirror of https://github.com/coqui-ai/TTS.git
rename convbnblocks and handle none mask
parent 921fa5db92
commit 6e9043c5d2
@@ -14,15 +14,27 @@ class ZeroTemporalPad(nn.Module):
         return self.pad_layer(x)
 
 
-class ConvBN(nn.Module):
-    def __init__(self, channels, kernel_size, dilation):
+class Conv1dBN(nn.Module):
+    """1d convolutional with batch norm.
+    conv1d -> relu -> BN blocks.
+
+    Note:
+        Batch normalization is applied after ReLU regarding the original implementation.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        kernel_size (int): kernel size for convolutional filters.
+        dilation (int): dilation for convolution layers.
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, dilation):
         super().__init__()
         padding = (dilation * (kernel_size - 1))
         pad_s = padding // 2
         pad_e = padding - pad_s
-        self.conv1d = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
+        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
         self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0)) # uneven left and right padding
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.BatchNorm1d(out_channels)
 
     def forward(self, x):
         o = self.conv1d(x)
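The hunk above cuts off early in Conv1dBN.forward, so here is a minimal, self-contained sketch of what the renamed layer does, assuming the conv1d -> relu -> BN order stated in its docstring; the class name Conv1dBNSketch and the tensor sizes are illustrative, not part of the commit. The point of the pad_s/pad_e split is that an unpadded dilated convolution shortens the time axis by dilation * (kernel_size - 1) steps, and the asymmetric zero padding grows it back to the input length.

import torch
from torch import nn

class Conv1dBNSketch(nn.Module):
    """Illustrative stand-in for the renamed Conv1dBN above (not the committed code)."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        # An unpadded dilated conv drops dilation * (kernel_size - 1) time steps ...
        padding = dilation * (kernel_size - 1)
        pad_s = padding // 2
        pad_e = padding - pad_s  # uneven split when padding is odd
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0))  # pads only the last (time) dim
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        o = self.conv1d(x)
        o = self.pad(o)    # ... and the asymmetric zero padding restores them
        o = torch.relu(o)
        return self.norm(o)

x = torch.randn(4, 80, 37)                        # (batch, channels, time)
y = Conv1dBNSketch(80, 128, kernel_size=4, dilation=3)(x)
print(y.shape)                                    # torch.Size([4, 128, 37])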
@@ -32,15 +44,27 @@ class ConvBN(nn.Module):
         return o
 
 
-class ConvBNBlock(nn.Module):
-    """Implements conv->PReLU->norm n-times"""
+class Conv1dBNBlock(nn.Module):
+    """1d convolutional block with batch norm. It is a set of conv1d -> relu -> BN blocks.
 
-    def __init__(self, channels, kernel_size, dilation, num_conv_blocks=2):
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of inner convolution channels.
+        kernel_size (int): kernel size for convolutional filters.
+        dilation (int): dilation for convolution layers.
+        num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
         super().__init__()
-        self.conv_bn_blocks = nn.Sequential(*[
-            ConvBN(channels, kernel_size, dilation)
-            for _ in range(num_conv_blocks)
-        ])
+        self.conv_bn_blocks = []
+        for idx in range(num_conv_blocks):
+            layer = Conv1dBN(in_channels if idx == 0 else hidden_channels,
+                             out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
+                             kernel_size,
+                             dilation)
+            self.conv_bn_blocks.append(layer)
+        self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks)
 
     def forward(self, x):
         """
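As a quick check on the new Conv1dBNBlock constructor, the snippet below replays its channel routing in isolation: the first Conv1dBN maps in_channels to hidden_channels, the last maps hidden_channels to out_channels, and any middle layers stay at hidden_channels (with num_conv_blocks=1 it maps in_channels straight to out_channels). The helper name channel_plan and the sizes are made up for illustration.

def channel_plan(in_channels, out_channels, hidden_channels, num_conv_blocks):
    """Return the (input, output) channel pair of each Conv1dBN sub-layer."""
    plan = []
    for idx in range(num_conv_blocks):
        c_in = in_channels if idx == 0 else hidden_channels
        c_out = out_channels if idx == (num_conv_blocks - 1) else hidden_channels
        plan.append((c_in, c_out))
    return plan

print(channel_plan(80, 256, 128, num_conv_blocks=2))  # [(80, 128), (128, 256)]
print(channel_plan(80, 256, 128, num_conv_blocks=3))  # [(80, 128), (128, 128), (128, 256)]
print(channel_plan(80, 256, 128, num_conv_blocks=1))  # [(80, 256)]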
@@ -50,16 +74,40 @@ class ConvBNBlock(nn.Module):
         return self.conv_bn_blocks(x)
 
 
-class ResidualConvBNBlock(nn.Module):
-    def __init__(self, channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2):
+class ResidualConv1dBNBlock(nn.Module):
+    """Residual Convolutional Blocks with BN
+    Each block has 'num_conv_block' conv layers and 'num_res_blocks' such blocks are connected
+    with residual connections.
+
+        conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks'
+        residuak_conv_block = (x -> conv_block -> + ->) x 'num_res_blocks'
+                               ' - - - - - - - - - ^
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of inner convolution channels.
+        kernel_size (int): kernel size for convolutional filters.
+        dilations (list): dilations for each convolution layer.
+        num_res_blocks (int, optional): number of residual blocks. Defaults to 13.
+        num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2):
+
         super().__init__()
         assert len(dilations) == num_res_blocks
         self.res_blocks = nn.ModuleList()
-        for dilation in dilations:
-            block = ConvBNBlock(channels, kernel_size, dilation, num_conv_blocks)
+        for idx, dilation in enumerate(dilations):
+            block = Conv1dBNBlock(in_channels if idx==0 else hidden_channels,
+                                  out_channels if (idx + 1) == len(dilations) else hidden_channels,
+                                  hidden_channels,
+                                  kernel_size,
+                                  dilation,
+                                  num_conv_blocks)
             self.res_blocks.append(block)
 
     def forward(self, x, x_mask=None):
+        if x_mask is None:
+            x_mask = 1.0
         o = x * x_mask
         for block in self.res_blocks:
             res = o