mirror of https://github.com/coqui-ai/TTS.git
mass refactoring and update
commit 79c841ccd3, parent 1d961d6f8a
@@ -391,6 +391,10 @@ class RelativePositionTransformer(nn.Module):
             y = self.ffn_layers[i](x, x_mask)
             y = self.dropout(y)
+
+            if (i + 1) == self.num_layers and hasattr(self, 'proj'):
+                x = self.proj(x)
+
             x = self.norm_layers_2[i](x + y)
         x = x * x_mask
         return x
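
The added branch applies the optional output projection on the final loop iteration only. Below is a minimal sketch of the same gating pattern with a hypothetical `proj` head that is registered only when a different output width is requested; this is an illustration, not the repo's actual constructor:

import torch
from torch import nn

class GatedProjectionStack(nn.Module):
    """Toy layer stack; `proj` exists only when a width change is requested."""
    def __init__(self, channels, out_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            [nn.Conv1d(channels, channels, 1) for _ in range(num_layers)])
        if channels != out_channels:
            self.proj = nn.Conv1d(channels, out_channels, 1)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x))
            # same gate as the diff: project once, on the last layer, if needed
            if (i + 1) == self.num_layers and hasattr(self, 'proj'):
                x = self.proj(x)
        return x

x = torch.rand(2, 64, 10)
assert GatedProjectionStack(64, 80, 3)(x).shape == (2, 80, 10)
assert GatedProjectionStack(64, 64, 3)(x).shape == (2, 64, 10)
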
@@ -1,9 +1,136 @@
 import torch
 from torch import nn
-from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
+from TTS.tts.layers.generic.wavenet import WNBlocks
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


+class WaveNetDecoder(nn.Module):
+    """WaveNet based decoder with a prenet and a postnet.
+
+    prenet: conv1d_1x1
+    postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1
+
+    TODO: Integrate speaker conditioning vector.
+
+    Note:
+        Default WaveNet parameters:
+            params = {
+                "num_blocks": 12,
+                "hidden_channels": 192,
+                "kernel_size": 5,
+                "dilation_rate": 1,
+                "num_layers": 4,
+                "dropout_p": 0.05
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels for the prenet and the postnet.
+        params (dict): dictionary for the WaveNet blocks.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
+        super().__init__()
+        # prenet
+        self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
+        # wavenet layers
+        self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
+        # postnet
+        self.postnet = [
+            torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, out_channels, 1),
+        ]
+        self.postnet = nn.Sequential(*self.postnet)
+
+    def forward(self, x, x_mask=None, g=None):
+        x = self.prenet(x) * x_mask
+        x = self.wn(x, x_mask, g)
+        o = self.postnet(x) * x_mask
+        return o
+
+
+class RelativePositionTransformerDecoder(nn.Module):
+    """Decoder with Relative Positional Transformer.
+
+    Note:
+        Default params
+            params = {
+                'hidden_channels_ffn': 128,
+                'num_heads': 2,
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 8,
+                "rel_attn_window_size": 4,
+                "input_length": None
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels, also used inside the Transformer layers.
+        params (dict): dictionary for the Transformer block.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
+        self.rel_pos_transformer = RelativePositionTransformer(
+            in_channels, out_channels, hidden_channels, **params)

+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        o = self.prenet(x) * x_mask
+        o = self.rel_pos_transformer(o, x_mask)
+        return o
+
+
+class ResidualConv1dBNDecoder(nn.Module):
+    """Residual convolutional decoder as in the original Speedy Speech paper.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Note:
+        Default params
+            params = {
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels, also used inside the ResidualConv1dBNBlock layers.
+        params (dict): dictionary for residual convolutional blocks.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels,
+                                                    hidden_channels, **params)
+        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
+        self.postnet = nn.Sequential(
+            Conv1dBNBlock(hidden_channels,
+                          hidden_channels,
+                          hidden_channels,
+                          params['kernel_size'],
+                          1,
+                          num_conv_blocks=2),
+            nn.Conv1d(hidden_channels, out_channels, 1),
+        )
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        o = self.res_conv_block(x, x_mask)
+        o = self.post_conv(o) + x
+        return self.postnet(o) * x_mask
+
+
 class Decoder(nn.Module):
     """Decodes the expanded phoneme encoding into spectrograms
     Args:

@@ -15,39 +142,8 @@ class Decoder(nn.Module):

     Shapes:
         - input: (B, C, T)

-    Note:
-        Default decoder_params...
-
-        for 'transformer'
-            decoder_params={
-                'hidden_channels_ffn': 128,
-                'num_heads': 2,
-                "kernel_size": 3,
-                "dropout_p": 0.1,
-                "num_layers": 8,
-                "rel_attn_window_size": 4,
-                "input_length": None
-            },
-
-        for 'residual_conv_bn'
-            decoder_params = {
-                "kernel_size": 4,
-                "dilations": 4 * [1, 2, 4, 8] + [1],
-                "num_conv_blocks": 2,
-                "num_res_blocks": 17
-            }
-
-        for 'wavenet'
-            decoder_params = {
-                "num_blocks": 12,
-                "hidden_channels": 192,
-                "kernel_size": 5,
-                "dilation_rate": 1,
-                "num_layers": 4,
-                "dropout_p": 0.05
-            }
     """

     # pylint: disable=dangerous-default-value
     def __init__(
             self,

@@ -62,28 +158,35 @@ class Decoder(nn.Module):
             },
             c_in_channels=0):
         super().__init__()
         self.in_channels = in_hidden_channels
         self.hidden_channels = in_hidden_channels
         self.out_channels = out_channels

         if decoder_type == 'transformer':
-            self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
+            self.decoder = RelativePositionTransformerDecoder(
+                in_channels=in_hidden_channels,
+                out_channels=out_channels,
+                hidden_channels=in_hidden_channels,
+                params=decoder_params)
         elif decoder_type == 'residual_conv_bn':
-            self.decoder = ResidualConvBNBlock(self.hidden_channels,
-                                               **decoder_params)
+            self.decoder = ResidualConv1dBNDecoder(
+                in_channels=in_hidden_channels,
+                out_channels=out_channels,
+                hidden_channels=in_hidden_channels,
+                params=decoder_params)
         elif decoder_type == 'wavenet':
-            self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
+            self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
+                                          out_channels=out_channels,
+                                          hidden_channels=in_hidden_channels,
+                                          c_in_channels=c_in_channels,
+                                          params=decoder_params)
         else:
             raise ValueError(f'[!] Unknown decoder type - {decoder_type}')

-        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
-        self.post_net = nn.Sequential(
-            ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
-            nn.Conv1d(self.hidden_channels, out_channels, 1),
-        )

     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         """
         Args:
             x: [B, C, T]
             x_mask: [B, 1, T]
             g: [B, C_g, 1]
         """
         # TODO: implement multi-speaker
-        o = self.decoder(x, x_mask)
-        o = self.post_conv(o) + x
-        return self.post_net(o) * x_mask
+        o = self.decoder(x, x_mask, g)
+        return o
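
For orientation, a minimal smoke test of the refactored factory, mirroring the unit tests further down; the import path is an assumption from the file layout at this commit:

import torch
# assumed module path for this file at this commit (illustrative)
from TTS.tts.layers.speedy_speech.decoder import Decoder

x = torch.rand(8, 128, 37)       # [B, C, T] expanded phoneme encoding
x_mask = torch.ones(8, 1, 37)    # [B, 1, T] all-ones mask for fixed-length input
# the default decoder_type is 'residual_conv_bn', as the unit tests below suggest
decoder = Decoder(out_channels=80, in_hidden_channels=128)
o = decoder(x, x_mask)
print(o.shape)                   # torch.Size([8, 80, 37])
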
@@ -1,6 +1,6 @@
 from torch import nn

-from TTS.tts.layers.generic.res_conv_bn import ConvBN
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBN


 class DurationPredictor(nn.Module):

@@ -21,9 +21,9 @@ class DurationPredictor(nn.Module):
         super().__init__()

         self.layers = nn.ModuleList([
-            ConvBN(hidden_channels, 4, 1),
-            ConvBN(hidden_channels, 3, 1),
-            ConvBN(hidden_channels, 1, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 4, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 3, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 1, 1),
             nn.Conv1d(hidden_channels, 1, 1)
         ])
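
The renamed block now takes explicit input and output channel counts, (in_channels, out_channels, kernel_size, dilation), instead of a single shared width. A plausible reading of such a block, sketched for illustration only; the real Conv1dBN lives in TTS.tts.layers.generic.res_conv_bn and may differ in padding and layer order:

import torch
from torch import nn

class Conv1dBNSketch(nn.Module):
    """Illustrative stand-in: conv -> ReLU -> BatchNorm, 'same'-padded."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        # padding='same' keeps T unchanged for stride-1 convolutions
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding='same')
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.norm(torch.relu(self.conv(x)))

x = torch.rand(8, 128, 27)
print(Conv1dBNSketch(128, 128, 4, 1)(x).shape)  # torch.Size([8, 128, 27])
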
@@ -1,11 +1,10 @@
 import math
 import torch
 from torch import nn
 from torch.nn import functional as F

 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock


 class PositionalEncoding(nn.Module):

@@ -18,12 +17,13 @@ class PositionalEncoding(nn.Module):
     def __init__(self, channels, dropout=0.0, max_len=5000):
         super().__init__()
         if channels % 2 != 0:
-            raise ValueError("Cannot use sin/cos positional encoding with "
-                             "odd channels (got channels={:d})".format(channels))
+            raise ValueError(
+                "Cannot use sin/cos positional encoding with "
+                "odd channels (got channels={:d})".format(channels))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
-                              -(math.log(10000.0) / channels)))
+                             -(math.log(10000.0) / channels)))
         pe[:, 0::2] = torch.sin(position.float() * div_term)
         pe[:, 1::2] = torch.cos(position.float() * div_term)
         pe = pe.unsqueeze(0).transpose(1, 2)
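
The div_term line is the usual sinusoidal schedule: exp(2i * (-ln 10000 / d)) equals 10000^(-2i/d), so even channels get sin(pos * 10000^(-2i/d)) and odd channels the matching cos. A quick check of that identity:

import math
import torch

channels = 8
two_i = torch.arange(0, channels, 2, dtype=torch.float)        # the 2i values
div_term = torch.exp(two_i * -(math.log(10000.0) / channels))  # as in the diff
closed_form = 10000.0 ** (-two_i / channels)                   # 10000^(-2i/d)
assert torch.allclose(div_term, closed_form)
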
@@ -59,9 +59,77 @@ class PositionalEncoding(nn.Module):
         return x


+class RelativePositionTransformerEncoder(nn.Module):
+    """Speedy Speech encoder built on a Transformer with relative position encoding.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels.
+        params (dict): dictionary for the Transformer block.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = ResidualConv1dBNBlock(in_channels,
+                                            hidden_channels,
+                                            hidden_channels,
+                                            kernel_size=5,
+                                            num_res_blocks=3,
+                                            num_conv_blocks=1,
+                                            dilations=[1, 1, 1])
+        self.rel_pos_transformer = RelativePositionTransformer(
+            hidden_channels, out_channels, hidden_channels, **params)
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        if x_mask is None:
+            x_mask = 1
+        o = self.prenet(x) * x_mask
+        o = self.rel_pos_transformer(o, x_mask)
+        return o
+
+
+class ResidualConv1dBNEncoder(nn.Module):
+    """Residual convolutional encoder as in the original Speedy Speech paper.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels.
+        params (dict): dictionary for residual convolutional blocks.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = nn.Sequential(
+            nn.Conv1d(in_channels, hidden_channels, 1),
+            nn.ReLU())
+        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
+                                                    hidden_channels,
+                                                    hidden_channels, **params)
+
+        self.postnet = nn.Sequential(*[
+            nn.Conv1d(hidden_channels, hidden_channels, 1),
+            nn.ReLU(),
+            nn.BatchNorm1d(hidden_channels),
+            nn.Conv1d(hidden_channels, out_channels, 1)
+        ])
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        if x_mask is None:
+            x_mask = 1
+        o = self.prenet(x) * x_mask
+        o = self.res_conv_block(o, x_mask)
+        o = self.postnet(o + x) * x_mask
+        return o * x_mask
+
+
 class Encoder(nn.Module):
     # pylint: disable=dangerous-default-value
-    """Speedy-Speech encoder using Transformers or Residual BN Convs internally.
+    """Factory class for the Speedy Speech encoder, switching between encoder types internally.

     Args:
         num_chars (int): number of characters.

@@ -114,29 +182,21 @@ class Encoder(nn.Module):

         # init encoder
         if encoder_type.lower() == "transformer":
-            # optional convolutional prenet
-            self.pre = ConvLayerNorm(self.in_channels,
-                                     self.hidden_channels,
-                                     self.hidden_channels,
-                                     kernel_size=5,
-                                     num_layers=3,
-                                     dropout_p=0.5)
-            # text encoder
-            self.encoder = RelativePositionTransformer(self.hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
+            self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
+                                                              out_channels,
+                                                              in_hidden_channels,
+                                                              encoder_params)  # pylint: disable=unexpected-keyword-arg
         elif encoder_type.lower() == 'residual_conv_bn':
-            self.pre = nn.Sequential(
-                nn.Conv1d(self.in_channels, self.hidden_channels, 1),
-                nn.ReLU())
-            self.encoder = ResidualConvBNBlock(self.hidden_channels,
-                                               **encoder_params)
+            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
+                                                   out_channels,
+                                                   in_hidden_channels,
+                                                   encoder_params)
         else:
-            raise NotImplementedError(' [!] encoder type not implemented.')
+            raise NotImplementedError(' [!] unknown encoder type.')

-        # final projection layers
-        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
-        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
-        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)

     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         """

@@ -145,15 +205,5 @@ class Encoder(nn.Module):
             x_mask: [B, 1, T]
             g: [B, C, 1]
         """
-        # TODO: implement multi-speaker
-        if self.encoder_type == 'transformer':
-            o = self.pre(x, x_mask)
-        else:
-            o = self.pre(x) * x_mask
-        o = self.encoder(o, x_mask)
-        o = self.post_conv(o + x)
-        o = F.relu(o)
-        o = self.post_bn(o)
-        o = self.post_conv2(o)
-        # [B, C, T]
+        o = self.encoder(x, x_mask)
         return o * x_mask
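
A minimal smoke test for the new residual encoder, reusing the residual_conv_bn defaults quoted in the decoder docstring above; the import path is an assumption from the file layout at this commit. Note that the residual `o + x` in forward requires in_channels to match hidden_channels:

import torch
# assumed module path for this file at this commit (illustrative)
from TTS.tts.layers.speedy_speech.encoder import ResidualConv1dBNEncoder

params = {
    "kernel_size": 4,
    "dilations": 4 * [1, 2, 4, 8] + [1],
    "num_conv_blocks": 2,
    "num_res_blocks": 17,
}
x = torch.rand(8, 128, 37)     # [B, C, T]
x_mask = torch.ones(8, 1, 37)  # [B, 1, T]
enc = ResidualConv1dBNEncoder(in_channels=128, out_channels=128,
                              hidden_channels=128, params=params)
print(enc(x, x_mask).shape)    # torch.Size([8, 128, 37])
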
@@ -1,7 +1,8 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, init_attn
+from .common_layers import Prenet
+from .attentions import init_attn


 class BatchNormConv1d(nn.Module):
@@ -1,7 +1,8 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from .common_layers import init_attn, Prenet, Linear
+from .common_layers import Prenet, Linear
+from .attentions import init_attn

 # NOTE: linter has a problem with the current TF release
 #pylint: disable=no-value-for-parameter
@@ -18,7 +18,7 @@ class Tacotron(TacotronAbstract):
         r (int): initial model reduction rate.
         postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
         decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
-        attn_type (str, optional): attention type. Check ```TTS.tts.layers.common_layers.init_attn```. Defaults to 'original'.
+        attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
         attn_win (bool, optional): enable/disable attention windowing.
             It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
         attn_norm (str, optional): attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
@@ -2,7 +2,7 @@ import tensorflow as tf
 from tensorflow import keras
 from TTS.tts.tf.utils.tf_utils import shape_list
 from TTS.tts.tf.layers.common_layers import Prenet, Attention
 # from tensorflow_addons.seq2seq import AttentionWrapper

 # NOTE: linter has a problem with the current TF release
 #pylint: disable=no-value-for-parameter
@@ -37,7 +37,7 @@
     "symmetric_norm": true, // move normalization to range [-1, 1]
     "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
     "clip_norm": true, // clip normalized values into the range.
-    "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. Scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
+    "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. Scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
 },

 // VOCABULARY PARAMETERS
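
With stats_path set to null, normalization falls back to the max_norm scaling configured above instead of mean-std statistics. A rough illustration of the two modes as the comments describe them (not the repo's actual AudioProcessor code); it assumes the feature is already min-max scaled to [0, 1]:

import numpy as np

def normalize(mel, stats=None, max_norm=4.0, symmetric=True, clip=True):
    # mean-std mode: used only when a scale_stats.npy file is configured
    if stats is not None:
        return (mel - stats["mean"]) / stats["std"]
    # range mode: map [0, 1] to [-max_norm, max_norm] (or [0, max_norm])
    out = max_norm * (2 * mel - 1) if symmetric else max_norm * mel
    if clip:
        out = np.clip(out, -max_norm if symmetric else 0, max_norm)
    return out

print(normalize(np.array([0.0, 0.5, 1.0])))  # [-4.  0.  4.]
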
@@ -50,10 +50,44 @@ def test_decoder():
     input_mask = torch.unsqueeze(
         sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

     # residual bn conv decoder
     layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]

+    # transformer decoder
+    layer = Decoder(out_channels=11,
+                    in_hidden_channels=128,
+                    decoder_type='transformer',
+                    decoder_params={
+                        'hidden_channels_ffn': 128,
+                        'num_heads': 2,
+                        "kernel_size": 3,
+                        "dropout_p": 0.1,
+                        "num_layers": 8,
+                        "rel_attn_window_size": 4,
+                        "input_length": None
+                    }).to(device)
+    output = layer(input_dummy, input_mask)
+    assert list(output.shape) == [8, 11, 37]
+
+    # wavenet decoder
+    layer = Decoder(out_channels=11,
+                    in_hidden_channels=128,
+                    decoder_type='wavenet',
+                    decoder_params={
+                        "num_blocks": 12,
+                        "hidden_channels": 192,
+                        "kernel_size": 5,
+                        "dilation_rate": 1,
+                        "num_layers": 4,
+                        "dropout_p": 0.05
+                    }).to(device)
+    output = layer(input_dummy, input_mask)
+    assert list(output.shape) == [8, 11, 37]


 def test_duration_predictor():
     input_dummy = torch.rand(8, 128, 27).to(device)