mass refactoring and update

erogol 2021-01-11 17:26:58 +01:00
parent 1d961d6f8a
commit 79c841ccd3
10 changed files with 287 additions and 94 deletions

View File

@ -391,6 +391,10 @@ class RelativePositionTransformer(nn.Module):
y = self.ffn_layers[i](x, x_mask)
y = self.dropout(y)
if (i + 1) == self.num_layers and hasattr(self, 'proj'):
x = self.proj(x)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x

View File

@ -1,9 +1,136 @@
import torch
from torch import nn
-from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
class WaveNetDecoder(nn.Module):
"""WaveNet based decoder with a prenet and a postnet.
prenet: conv1d_1x1
postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1
TODO: Integrate speaker conditioning vector.
Note:
Default WaveNet parameters:
params = {
"num_blocks": 12,
"hidden_channels":192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
hidden_channels (int): number of hidden channels for the postnet.
c_in_channels (int): number of channels of the conditioning input (e.g. a speaker embedding).
params (dict): dictionary of WaveNet block parameters.
"""
def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
super().__init__()
# prenet
self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
# wavenet layers
self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
# postnet
self.postnet = [
torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(hidden_channels, out_channels, 1),
]
self.postnet = nn.Sequential(*self.postnet)
def forward(self, x, x_mask=None, g=None):
if x_mask is None:
x_mask = 1
x = self.prenet(x) * x_mask
x = self.wn(x, x_mask, g)
o = self.postnet(x) * x_mask
return o
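Example (not part of the diff): a minimal usage sketch for WaveNetDecoder above, using the default WaveNet parameters quoted in its docstring; the channel sizes and tensor shapes are illustrative assumptions, not values fixed by this commit.

import torch

params = {
    "num_blocks": 12,
    "hidden_channels": 192,
    "kernel_size": 5,
    "dilation_rate": 1,
    "num_layers": 4,
    "dropout_p": 0.05,
}
# assumed sizes: 128 input/hidden channels, 80 output (mel) channels
decoder = WaveNetDecoder(in_channels=128, out_channels=80,
                         hidden_channels=128, c_in_channels=0, params=params)
x = torch.rand(8, 128, 37)        # input features [B, C_in, T]
x_mask = torch.ones(8, 1, 37)     # binary padding mask [B, 1, T]
o = decoder(x, x_mask)            # -> [8, 80, 37]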
class RelativePositionTransformerDecoder(nn.Module):
"""Decoder with Relative Positional Transformer.
Note:
Default params
params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 8,
"rel_attn_window_size": 4,
"input_length": None
}
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
hidden_channels (int): number of hidden channels including Transformer layers.
params (dict): dictionary of transformer parameters.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
self.rel_pos_transformer = RelativePositionTransformer(
in_channels, out_channels, hidden_channels, **params)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
x_mask = 1
o = self.prenet(x) * x_mask
o = self.rel_pos_transformer(o, x_mask)
return o
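Example (not part of the diff): the same sketch for RelativePositionTransformerDecoder with the default transformer parameters from its docstring. hidden_channels must be divisible by num_heads; sizes are illustrative.

import torch

params = {
    "hidden_channels_ffn": 128,
    "num_heads": 2,
    "kernel_size": 3,
    "dropout_p": 0.1,
    "num_layers": 8,
    "rel_attn_window_size": 4,
    "input_length": None,
}
decoder = RelativePositionTransformerDecoder(
    in_channels=128, out_channels=80, hidden_channels=128, params=params)
x = torch.rand(8, 128, 37)
x_mask = torch.ones(8, 1, 37)
o = decoder(x, x_mask)            # -> [8, 80, 37]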
class ResidualConv1dBNDecoder(nn.Module):
"""Residual Convolutional Decoder as in the original Speedy Speech paper
TODO: Integrate speaker conditioning vector.
Note:
Default params
params = {
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
}
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.res_conv_block = ResidualConv1dBNBlock(in_channels,
hidden_channels,
hidden_channels, **params)
self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
self.postnet = nn.Sequential(
Conv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels,
params['kernel_size'],
1,
num_conv_blocks=2),
nn.Conv1d(hidden_channels, out_channels, 1),
)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
x_mask = 1
o = self.res_conv_block(x, x_mask)
o = self.post_conv(o) + x
return self.postnet(o) * x_mask
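Example (not part of the diff): usage sketch for ResidualConv1dBNDecoder with the Speedy Speech defaults quoted in its docstring. Note that in_channels must equal hidden_channels because forward() adds the raw input back in post_conv(o) + x; sizes are illustrative.

import torch

params = {
    "kernel_size": 4,
    "dilations": 4 * [1, 2, 4, 8] + [1],   # 17 dilation values, one per res block
    "num_conv_blocks": 2,
    "num_res_blocks": 17,
}
decoder = ResidualConv1dBNDecoder(
    in_channels=128, out_channels=80, hidden_channels=128, params=params)
x = torch.rand(8, 128, 37)
x_mask = torch.ones(8, 1, 37)
o = decoder(x, x_mask)            # -> [8, 80, 37]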
class Decoder(nn.Module):
"""Decodes the expanded phoneme encoding into spectrograms
Args:
@ -15,39 +142,8 @@ class Decoder(nn.Module):
Shapes:
- input: (B, C, T)
Note:
Default decoder_params...
for 'transformer'
decoder_params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 8,
"rel_attn_window_size": 4,
"input_length": None
},
for 'residual_conv_bn'
decoder_params = {
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
}
for 'wavenet'
decoder_params = {
"num_blocks": 12,
"hidden_channels":192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
@ -62,28 +158,35 @@ class Decoder(nn.Module):
},
c_in_channels=0):
super().__init__()
self.in_channels = in_hidden_channels
self.hidden_channels = in_hidden_channels
self.out_channels = out_channels
if decoder_type == 'transformer':
-self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
+self.decoder = RelativePositionTransformerDecoder(
+in_channels=in_hidden_channels,
+out_channels=out_channels,
+hidden_channels=in_hidden_channels,
+params=decoder_params)
elif decoder_type == 'residual_conv_bn':
-self.decoder = ResidualConvBNBlock(self.hidden_channels,
-**decoder_params)
+self.decoder = ResidualConv1dBNDecoder(
+in_channels=in_hidden_channels,
+out_channels=out_channels,
+hidden_channels=in_hidden_channels,
+params=decoder_params)
elif decoder_type == 'wavenet':
-self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
+self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
+out_channels=out_channels,
+hidden_channels=in_hidden_channels,
+c_in_channels=c_in_channels,
+params=decoder_params)
else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
-self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
-self.post_net = nn.Sequential(
-ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
-nn.Conv1d(self.hidden_channels, out_channels, 1),
-)
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
"""
Args:
x: [B, C, T]
x_mask: [B, 1, T]
g: [B, C_g, 1]
"""
# TODO: implement multi-speaker
-o = self.decoder(x, x_mask)
-o = self.post_conv(o) + x
-return self.post_net(o) * x_mask
+o = self.decoder(x, x_mask, g)
+return o
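Example (not part of the diff): after the refactor, Decoder is a thin factory; each decoder type owns its prenet/postnet, which is why the old shared post_conv/post_net above were removed. A sketch of the 'wavenet' route, mirroring the updated tests at the bottom of this commit; the output size 80 is an illustrative assumption.

import torch

decoder = Decoder(out_channels=80,
                  in_hidden_channels=128,
                  decoder_type='wavenet',
                  decoder_params={
                      "num_blocks": 12,
                      "hidden_channels": 192,
                      "kernel_size": 5,
                      "dilation_rate": 1,
                      "num_layers": 4,
                      "dropout_p": 0.05
                  },
                  c_in_channels=0)  # speaker conditioning channels; multi-speaker is still a TODO
x = torch.rand(8, 128, 37)
x_mask = torch.ones(8, 1, 37)
o = decoder(x, x_mask)            # -> [8, 80, 37]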

View File

@ -1,6 +1,6 @@
from torch import nn
-from TTS.tts.layers.generic.res_conv_bn import ConvBN
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBN
class DurationPredictor(nn.Module):
@ -21,9 +21,9 @@ class DurationPredictor(nn.Module):
super().__init__()
self.layers = nn.ModuleList([
-ConvBN(hidden_channels, 4, 1),
-ConvBN(hidden_channels, 3, 1),
-ConvBN(hidden_channels, 1, 1),
+Conv1dBN(hidden_channels, hidden_channels, 4, 1),
+Conv1dBN(hidden_channels, hidden_channels, 3, 1),
+Conv1dBN(hidden_channels, hidden_channels, 1, 1),
nn.Conv1d(hidden_channels, 1, 1)
])
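Example (not part of the diff): a usage sketch for the refactored DurationPredictor. The constructor and forward signatures are not shown in this hunk, so the hidden_channels-only constructor and the (x, x_mask) forward below are assumptions inferred from the layer list above.

import torch

predictor = DurationPredictor(hidden_channels=128)   # assumed signature
x = torch.rand(8, 128, 27)        # encoder output [B, C, T]
x_mask = torch.ones(8, 1, 27)
durations = predictor(x, x_mask)  # assumed forward; final 1x1 conv -> [8, 1, 27]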

View File

@ -1,11 +1,10 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
class PositionalEncoding(nn.Module):
@ -18,12 +17,13 @@ class PositionalEncoding(nn.Module):
def __init__(self, channels, dropout=0.0, max_len=5000):
super().__init__()
if channels % 2 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd channels (got channels={:d})".format(channels))
raise ValueError(
"Cannot use sin/cos positional encoding with "
"odd channels (got channels={:d})".format(channels))
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
-(math.log(10000.0) / channels)))
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0).transpose(1, 2)
@ -59,9 +59,77 @@ class PositionalEncoding(nn.Module):
return x
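Aside (not part of the diff): the buffer built above is the standard sinusoidal encoding, PE[pos, 2i] = sin(pos / 10000^(2i/C)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/C)), stored channel-first so it can be added to [B, C, T] feature maps. A standalone sketch of the same computation:

import math
import torch

channels, max_len = 128, 5000
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, channels, 2).float()
                     * -(math.log(10000.0) / channels))   # 10000^(-2i/C)
pe[:, 0::2] = torch.sin(position * div_term)   # even channels
pe[:, 1::2] = torch.cos(position * div_term)   # odd channels
pe = pe.unsqueeze(0).transpose(1, 2)           # [1, C, T] to match conv layout
assert pe.shape == (1, channels, max_len)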
class RelativePositionTransformerEncoder(nn.Module):
"""Speedy speech encoder built on Transformer with Relative Position encoding.
TODO: Integrate speaker conditioning vector.
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
hidden_channels (int): number of hidden channels.
params (dict): dictionary of transformer parameters.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = ResidualConv1dBNBlock(in_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_res_blocks=3,
num_conv_blocks=1,
dilations=[1, 1, 1]
)
self.rel_pos_transformer = RelativePositionTransformer(
hidden_channels, out_channels, hidden_channels, **params)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
x_mask = 1
o = self.prenet(x) * x_mask
o = self.rel_pos_transformer(o, x_mask)
return o
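Example (not part of the diff): usage sketch for RelativePositionTransformerEncoder; the parameter values below are illustrative, not defaults confirmed by this hunk.

import torch

params = {
    "hidden_channels_ffn": 128,
    "num_heads": 2,
    "kernel_size": 3,
    "dropout_p": 0.1,
    "num_layers": 6,
    "rel_attn_window_size": 4,
    "input_length": None,
}
encoder = RelativePositionTransformerEncoder(
    in_channels=128, out_channels=128, hidden_channels=128, params=params)
x = torch.rand(8, 128, 37)        # character embeddings [B, C, T]
x_mask = torch.ones(8, 1, 37)
o = encoder(x, x_mask)            # -> [8, 128, 37]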
class ResidualConv1dBNEncoder(nn.Module):
"""Residual Convolutional Encoder as in the original Speedy Speech paper
TODO: Integrate speaker conditioning vector.
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
hidden_channels (int): number of hidden channels
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = nn.Sequential(
nn.Conv1d(in_channels, hidden_channels, 1),
nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels, **params)
self.postnet = nn.Sequential(*[
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU(),
nn.BatchNorm1d(hidden_channels),
nn.Conv1d(hidden_channels, out_channels, 1)
])
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
x_mask = 1
o = self.prenet(x) * x_mask
o = self.res_conv_block(o, x_mask)
o = self.postnet(o + x) * x_mask
return o * x_mask
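Example (not part of the diff): usage sketch for ResidualConv1dBNEncoder. The residual conv parameters below follow the Speedy Speech pattern but are assumptions; note forward() adds the raw input back (postnet(o + x)), so in_channels must equal hidden_channels.

import torch

params = {
    "kernel_size": 4,
    "dilations": 4 * [1, 2, 4] + [1],   # 13 dilation values, one per res block
    "num_conv_blocks": 2,
    "num_res_blocks": 13,
}
encoder = ResidualConv1dBNEncoder(
    in_channels=128, out_channels=128, hidden_channels=128, params=params)
x = torch.rand(8, 128, 37)
x_mask = torch.ones(8, 1, 37)
o = encoder(x, x_mask)            # -> [8, 128, 37]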
class Encoder(nn.Module):
# pylint: disable=dangerous-default-value
"""Speedy-Speech encoder using Transformers or Residual BN Convs internally.
"""Factory class for Speedy Speech encoder enables different encoder types internally.
Args:
num_chars (int): number of characters.
@ -114,29 +182,21 @@ class Encoder(nn.Module):
# init encoder
if encoder_type.lower() == "transformer":
-# optional convolutional prenet
-self.pre = ConvLayerNorm(self.in_channels,
-self.hidden_channels,
-self.hidden_channels,
-kernel_size=5,
-num_layers=3,
-dropout_p=0.5)
-# text encoder
-self.encoder = RelativePositionTransformer(self.hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
+self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
+out_channels,
+in_hidden_channels,
+encoder_params) # pylint: disable=unexpected-keyword-arg
elif encoder_type.lower() == 'residual_conv_bn':
-self.pre = nn.Sequential(
-nn.Conv1d(self.in_channels, self.hidden_channels, 1),
-nn.ReLU())
-self.encoder = ResidualConvBNBlock(self.hidden_channels,
-**encoder_params)
+self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
+out_channels,
+in_hidden_channels,
+encoder_params)
else:
-raise NotImplementedError(' [!] encoder type not implemented.')
+raise NotImplementedError(' [!] unknown encoder type.')
-# final projection layers
-self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
-1)
-self.post_bn = nn.BatchNorm1d(self.hidden_channels)
-self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
"""
@ -145,15 +205,5 @@ class Encoder(nn.Module):
x_mask: [B, 1, T]
g: [B, C, 1]
"""
-# TODO: implement multi-speaker
-if self.encoder_type == 'transformer':
-o = self.pre(x, x_mask)
-else:
-o = self.pre(x) * x_mask
-o = self.encoder(o, x_mask)
-o = self.post_conv(o + x)
-o = F.relu(o)
-o = self.post_bn(o)
-o = self.post_conv2(o)
+# [B, C, T]
+o = self.encoder(x, x_mask)
return o * x_mask
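Example (not part of the diff): with this refactor the pre/post processing moved inside the encoder classes, leaving Encoder as a pure dispatcher. A hedged factory call; the keyword names are taken from the constructor body above, since the full signature is not shown in this hunk.

import torch

encoder = Encoder(in_hidden_channels=128,
                  out_channels=128,
                  encoder_type='residual_conv_bn',
                  encoder_params={
                      "kernel_size": 4,
                      "dilations": 4 * [1, 2, 4] + [1],
                      "num_conv_blocks": 2,
                      "num_res_blocks": 13
                  })
x = torch.rand(8, 128, 37)        # embedded character sequence [B, C, T]
x_mask = torch.ones(8, 1, 37)
o = encoder(x, x_mask)            # -> [8, 128, 37]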

View File

@ -1,7 +1,8 @@
# coding: utf-8
import torch
from torch import nn
-from .common_layers import Prenet, init_attn
+from .common_layers import Prenet
+from .attentions import init_attn
class BatchNormConv1d(nn.Module):

View File

@ -1,7 +1,8 @@
import torch
from torch import nn
from torch.nn import functional as F
-from .common_layers import init_attn, Prenet, Linear
+from .common_layers import Prenet, Linear
+from .attentions import init_attn
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter

View File

@ -18,7 +18,7 @@ class Tacotron(TacotronAbstract):
r (int): initial model reduction rate.
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
-attn_type (str, optional): attention type. Check ```TTS.tts.layers.common_layers.init_attn```. Defaults to 'original'.
+attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
attn_win (bool, optional): enable/disable attention windowing.
It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".

View File

@ -2,7 +2,7 @@ import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.utils.tf_utils import shape_list
from TTS.tts.tf.layers.common_layers import Prenet, Attention
# from tensorflow_addons.seq2seq import AttentionWrapper
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter

View File

@ -37,7 +37,7 @@
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
// VOCABULARY PARAMETERS

View File

@ -50,10 +50,44 @@ def test_decoder():
input_mask = torch.unsqueeze(
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
# residual bn conv decoder
layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# transformer decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='transformer',
decoder_params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 8,
"rel_attn_window_size": 4,
"input_length": None
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# wavenet decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='wavenet',
decoder_params={
"num_blocks": 12,
"hidden_channels": 192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
def test_duration_predictor():
input_dummy = torch.rand(8, 128, 27).to(device)