mirror of https://github.com/coqui-ai/TTS.git
mass refactoring and update
This commit is contained in:
parent
1d961d6f8a
commit
79c841ccd3
|
@ -391,6 +391,10 @@ class RelativePositionTransformer(nn.Module):
|
||||||
|
|
||||||
y = self.ffn_layers[i](x, x_mask)
|
y = self.ffn_layers[i](x, x_mask)
|
||||||
y = self.dropout(y)
|
y = self.dropout(y)
|
||||||
|
|
||||||
|
if (i + 1) == self.num_layers and hasattr(self, 'proj'):
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
x = self.norm_layers_2[i](x + y)
|
x = self.norm_layers_2[i](x + y)
|
||||||
x = x * x_mask
|
x = x * x_mask
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -1,9 +1,136 @@
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
|
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
|
||||||
from TTS.tts.layers.generic.wavenet import WNBlocks
|
from TTS.tts.layers.generic.wavenet import WNBlocks
|
||||||
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class WaveNetDecoder(nn.Module):
|
||||||
|
"""WaveNet based decoder with a prenet and a postnet.
|
||||||
|
|
||||||
|
prenet: conv1d_1x1
|
||||||
|
postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1
|
||||||
|
|
||||||
|
TODO: Integrate speaker conditioning vector.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
default wavenet parameters;
|
||||||
|
params = {
|
||||||
|
"num_blocks": 12,
|
||||||
|
"hidden_channels":192,
|
||||||
|
"kernel_size": 5,
|
||||||
|
"dilation_rate": 1,
|
||||||
|
"num_layers": 4,
|
||||||
|
"dropout_p": 0.05
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels.
|
||||||
|
out_channels (int): number of output channels.
|
||||||
|
hidden_channels (int): number of hidden channels for prenet and postnet.
|
||||||
|
params (dict): dictionary for residual convolutional blocks.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
|
||||||
|
super().__init__()
|
||||||
|
# prenet
|
||||||
|
self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
|
||||||
|
# wavenet layers
|
||||||
|
self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
|
||||||
|
# postnet
|
||||||
|
self.postnet = [
|
||||||
|
torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv1d(hidden_channels, out_channels, 1),
|
||||||
|
]
|
||||||
|
self.postnet = nn.Sequential(*self.postnet)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None, g=None):
|
||||||
|
x = self.prenet(x) * x_mask
|
||||||
|
x = self.wn(x, x_mask, g)
|
||||||
|
o = self.postnet(x) * x_mask
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionTransformerDecoder(nn.Module):
|
||||||
|
"""Decoder with Relative Positional Transformer.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Default params
|
||||||
|
params={
|
||||||
|
'hidden_channels_ffn': 128,
|
||||||
|
'num_heads': 2,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"dropout_p": 0.1,
|
||||||
|
"num_layers": 8,
|
||||||
|
"rel_attn_window_size": 4,
|
||||||
|
"input_length": None
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels.
|
||||||
|
out_channels (int): number of output channels.
|
||||||
|
hidden_channels (int): number of hidden channels including Transformer layers.
|
||||||
|
params (dict): dictionary for residual convolutional blocks.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
|
||||||
|
self.rel_pos_transformer = RelativePositionTransformer(
|
||||||
|
in_channels, out_channels, hidden_channels, **params)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
o = self.prenet(x) * x_mask
|
||||||
|
o = self.rel_pos_transformer(o, x_mask)
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConv1dBNDecoder(nn.Module):
|
||||||
|
"""Residual Convolutional Decoder as in the original Speedy Speech paper
|
||||||
|
|
||||||
|
TODO: Integrate speaker conditioning vector.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Default params
|
||||||
|
params = {
|
||||||
|
"kernel_size": 4,
|
||||||
|
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||||
|
"num_conv_blocks": 2,
|
||||||
|
"num_res_blocks": 17
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels.
|
||||||
|
out_channels (int): number of output channels.
|
||||||
|
hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
|
||||||
|
params (dict): dictionary for residual convolutional blocks.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||||
|
super().__init__()
|
||||||
|
self.res_conv_block = ResidualConv1dBNBlock(in_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels, **params)
|
||||||
|
self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||||
|
self.postnet = nn.Sequential(
|
||||||
|
Conv1dBNBlock(hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
params['kernel_size'],
|
||||||
|
1,
|
||||||
|
num_conv_blocks=2),
|
||||||
|
nn.Conv1d(hidden_channels, out_channels, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
o = self.res_conv_block(x, x_mask)
|
||||||
|
o = self.post_conv(o) + x
|
||||||
|
return self.postnet(o) * x_mask
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""Decodes the expanded phoneme encoding into spectrograms
|
"""Decodes the expanded phoneme encoding into spectrograms
|
||||||
Args:
|
Args:
|
||||||
|
@ -15,39 +142,8 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- input: (B, C, T)
|
- input: (B, C, T)
|
||||||
|
|
||||||
Note:
|
|
||||||
Default decoder_params...
|
|
||||||
|
|
||||||
for 'transformer'
|
|
||||||
decoder_params={
|
|
||||||
'hidden_channels_ffn': 128,
|
|
||||||
'num_heads': 2,
|
|
||||||
"kernel_size": 3,
|
|
||||||
"dropout_p": 0.1,
|
|
||||||
"num_layers": 8,
|
|
||||||
"rel_attn_window_size": 4,
|
|
||||||
"input_length": None
|
|
||||||
},
|
|
||||||
|
|
||||||
for 'residual_conv_bn'
|
|
||||||
decoder_params = {
|
|
||||||
"kernel_size": 4,
|
|
||||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
|
||||||
"num_conv_blocks": 2,
|
|
||||||
"num_res_blocks": 17
|
|
||||||
}
|
|
||||||
|
|
||||||
for 'wavenet'
|
|
||||||
decoder_params = {
|
|
||||||
"num_blocks": 12,
|
|
||||||
"hidden_channels":192,
|
|
||||||
"kernel_size": 5,
|
|
||||||
"dilation_rate": 1,
|
|
||||||
"num_layers": 4,
|
|
||||||
"dropout_p": 0.05
|
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -62,28 +158,35 @@ class Decoder(nn.Module):
|
||||||
},
|
},
|
||||||
c_in_channels=0):
|
c_in_channels=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_hidden_channels
|
|
||||||
self.hidden_channels = in_hidden_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
if decoder_type == 'transformer':
|
if decoder_type == 'transformer':
|
||||||
self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
|
self.decoder = RelativePositionTransformerDecoder(
|
||||||
|
in_channels=in_hidden_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
hidden_channels=in_hidden_channels,
|
||||||
|
params=decoder_params)
|
||||||
elif decoder_type == 'residual_conv_bn':
|
elif decoder_type == 'residual_conv_bn':
|
||||||
self.decoder = ResidualConvBNBlock(self.hidden_channels,
|
self.decoder = ResidualConv1dBNDecoder(
|
||||||
**decoder_params)
|
in_channels=in_hidden_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
hidden_channels=in_hidden_channels,
|
||||||
|
params=decoder_params)
|
||||||
elif decoder_type == 'wavenet':
|
elif decoder_type == 'wavenet':
|
||||||
self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
|
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
hidden_channels=in_hidden_channels,
|
||||||
|
c_in_channels=c_in_channels,
|
||||||
|
params=decoder_params)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
|
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
|
||||||
|
|
||||||
self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
|
|
||||||
self.post_net = nn.Sequential(
|
|
||||||
ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
|
|
||||||
nn.Conv1d(self.hidden_channels, out_channels, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
|
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, C, T]
|
||||||
|
x_mask: [B, 1, T]
|
||||||
|
g: [B, C_g, 1]
|
||||||
|
"""
|
||||||
# TODO: implement multi-speaker
|
# TODO: implement multi-speaker
|
||||||
o = self.decoder(x, x_mask)
|
o = self.decoder(x, x_mask, g)
|
||||||
o = self.post_conv(o) + x
|
return o
|
||||||
return self.post_net(o) * x_mask
|
|
|
@ -1,6 +1,6 @@
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from TTS.tts.layers.generic.res_conv_bn import ConvBN
|
from TTS.tts.layers.generic.res_conv_bn import Conv1dBN
|
||||||
|
|
||||||
|
|
||||||
class DurationPredictor(nn.Module):
|
class DurationPredictor(nn.Module):
|
||||||
|
@ -21,9 +21,9 @@ class DurationPredictor(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
ConvBN(hidden_channels, 4, 1),
|
Conv1dBN(hidden_channels, hidden_channels, 4, 1),
|
||||||
ConvBN(hidden_channels, 3, 1),
|
Conv1dBN(hidden_channels, hidden_channels, 3, 1),
|
||||||
ConvBN(hidden_channels, 1, 1),
|
Conv1dBN(hidden_channels, hidden_channels, 1, 1),
|
||||||
nn.Conv1d(hidden_channels, 1, 1)
|
nn.Conv1d(hidden_channels, 1, 1)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||||
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
|
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
|
||||||
from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
|
@ -18,12 +17,13 @@ class PositionalEncoding(nn.Module):
|
||||||
def __init__(self, channels, dropout=0.0, max_len=5000):
|
def __init__(self, channels, dropout=0.0, max_len=5000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if channels % 2 != 0:
|
if channels % 2 != 0:
|
||||||
raise ValueError("Cannot use sin/cos positional encoding with "
|
raise ValueError(
|
||||||
"odd channels (got channels={:d})".format(channels))
|
"Cannot use sin/cos positional encoding with "
|
||||||
|
"odd channels (got channels={:d})".format(channels))
|
||||||
pe = torch.zeros(max_len, channels)
|
pe = torch.zeros(max_len, channels)
|
||||||
position = torch.arange(0, max_len).unsqueeze(1)
|
position = torch.arange(0, max_len).unsqueeze(1)
|
||||||
div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
|
div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
|
||||||
-(math.log(10000.0) / channels)))
|
-(math.log(10000.0) / channels)))
|
||||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||||
pe = pe.unsqueeze(0).transpose(1, 2)
|
pe = pe.unsqueeze(0).transpose(1, 2)
|
||||||
|
@ -59,9 +59,77 @@ class PositionalEncoding(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionTransformerEncoder(nn.Module):
|
||||||
|
"""Speedy speech encoder built on Transformer with Relative Position encoding.
|
||||||
|
|
||||||
|
TODO: Integrate speaker conditioning vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels.
|
||||||
|
out_channels (int): number of output channels.
|
||||||
|
hidden_channels (int): number of hidden channels
|
||||||
|
params (dict): dictionary for residual convolutional blocks.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||||
|
super().__init__()
|
||||||
|
self.prenet = ResidualConv1dBNBlock(in_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size=5,
|
||||||
|
num_res_blocks=3,
|
||||||
|
num_conv_blocks=1,
|
||||||
|
dilations=[1, 1, 1]
|
||||||
|
)
|
||||||
|
self.rel_pos_transformer = RelativePositionTransformer(
|
||||||
|
hidden_channels, out_channels, hidden_channels, **params)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
if x_mask is None:
|
||||||
|
x_mask = 1
|
||||||
|
o = self.prenet(x) * x_mask
|
||||||
|
o = self.rel_pos_transformer(o, x_mask)
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConv1dBNEncoder(nn.Module):
|
||||||
|
"""Residual Convolutional Encoder as in the original Speedy Speech paper
|
||||||
|
|
||||||
|
TODO: Integrate speaker conditioning vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input channels.
|
||||||
|
out_channels (int): number of output channels.
|
||||||
|
hidden_channels (int): number of hidden channels
|
||||||
|
params (dict): dictionary for residual convolutional blocks.
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||||
|
super().__init__()
|
||||||
|
self.prenet = nn.Sequential(
|
||||||
|
nn.Conv1d(in_channels, hidden_channels, 1),
|
||||||
|
nn.ReLU())
|
||||||
|
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels, **params)
|
||||||
|
|
||||||
|
self.postnet = nn.Sequential(*[
|
||||||
|
nn.Conv1d(hidden_channels, hidden_channels, 1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.BatchNorm1d(hidden_channels),
|
||||||
|
nn.Conv1d(hidden_channels, out_channels, 1)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
if x_mask is None:
|
||||||
|
x_mask = 1
|
||||||
|
o = self.prenet(x) * x_mask
|
||||||
|
o = self.res_conv_block(o, x_mask)
|
||||||
|
o = self.postnet(o + x) * x_mask
|
||||||
|
return o * x_mask
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
"""Speedy-Speech encoder using Transformers or Residual BN Convs internally.
|
"""Factory class for Speedy Speech encoder enables different encoder types internally.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_chars (int): number of characters.
|
num_chars (int): number of characters.
|
||||||
|
@ -114,29 +182,21 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
# init encoder
|
# init encoder
|
||||||
if encoder_type.lower() == "transformer":
|
if encoder_type.lower() == "transformer":
|
||||||
# optional convolutional prenet
|
|
||||||
self.pre = ConvLayerNorm(self.in_channels,
|
|
||||||
self.hidden_channels,
|
|
||||||
self.hidden_channels,
|
|
||||||
kernel_size=5,
|
|
||||||
num_layers=3,
|
|
||||||
dropout_p=0.5)
|
|
||||||
# text encoder
|
# text encoder
|
||||||
self.encoder = RelativePositionTransformer(self.hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
|
self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
in_hidden_channels,
|
||||||
|
encoder_params) # pylint: disable=unexpected-keyword-arg
|
||||||
elif encoder_type.lower() == 'residual_conv_bn':
|
elif encoder_type.lower() == 'residual_conv_bn':
|
||||||
self.pre = nn.Sequential(
|
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
|
||||||
nn.Conv1d(self.in_channels, self.hidden_channels, 1),
|
out_channels,
|
||||||
nn.ReLU())
|
in_hidden_channels,
|
||||||
self.encoder = ResidualConvBNBlock(self.hidden_channels,
|
encoder_params)
|
||||||
**encoder_params)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(' [!] encoder type not implemented.')
|
raise NotImplementedError(' [!] unknown encoder type.')
|
||||||
|
|
||||||
# final projection layers
|
# final projection layers
|
||||||
self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
|
|
||||||
1)
|
|
||||||
self.post_bn = nn.BatchNorm1d(self.hidden_channels)
|
|
||||||
self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
|
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
|
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
|
@ -145,15 +205,5 @@ class Encoder(nn.Module):
|
||||||
x_mask: [B, 1, T]
|
x_mask: [B, 1, T]
|
||||||
g: [B, C, 1]
|
g: [B, C, 1]
|
||||||
"""
|
"""
|
||||||
# TODO: implement multi-speaker
|
o = self.encoder(x, x_mask)
|
||||||
if self.encoder_type == 'transformer':
|
|
||||||
o = self.pre(x, x_mask)
|
|
||||||
else:
|
|
||||||
o = self.pre(x) * x_mask
|
|
||||||
o = self.encoder(o, x_mask)
|
|
||||||
o = self.post_conv(o + x)
|
|
||||||
o = F.relu(o)
|
|
||||||
o = self.post_bn(o)
|
|
||||||
o = self.post_conv2(o)
|
|
||||||
# [B, C, T]
|
|
||||||
return o * x_mask
|
return o * x_mask
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .common_layers import Prenet, init_attn
|
from .common_layers import Prenet
|
||||||
|
from .attentions import init_attn
|
||||||
|
|
||||||
|
|
||||||
class BatchNormConv1d(nn.Module):
|
class BatchNormConv1d(nn.Module):
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from .common_layers import init_attn, Prenet, Linear
|
from .common_layers import Prenet, Linear
|
||||||
|
from .attentions import init_attn
|
||||||
|
|
||||||
# NOTE: linter has a problem with the current TF release
|
# NOTE: linter has a problem with the current TF release
|
||||||
#pylint: disable=no-value-for-parameter
|
#pylint: disable=no-value-for-parameter
|
||||||
|
|
|
@ -18,7 +18,7 @@ class Tacotron(TacotronAbstract):
|
||||||
r (int): initial model reduction rate.
|
r (int): initial model reduction rate.
|
||||||
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
|
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
|
||||||
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
|
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
|
||||||
attn_type (str, optional): attention type. Check ```TTS.tts.layers.common_layers.init_attn```. Defaults to 'original'.
|
attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
|
||||||
attn_win (bool, optional): enable/disable attention windowing.
|
attn_win (bool, optional): enable/disable attention windowing.
|
||||||
It especially useful at inference to keep attention alignment diagonal. Defaults to False.
|
It especially useful at inference to keep attention alignment diagonal. Defaults to False.
|
||||||
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
|
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
|
||||||
|
|
|
@ -2,7 +2,7 @@ import tensorflow as tf
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from TTS.tts.tf.utils.tf_utils import shape_list
|
from TTS.tts.tf.utils.tf_utils import shape_list
|
||||||
from TTS.tts.tf.layers.common_layers import Prenet, Attention
|
from TTS.tts.tf.layers.common_layers import Prenet, Attention
|
||||||
# from tensorflow_addons.seq2seq import AttentionWrapper
|
|
||||||
|
|
||||||
# NOTE: linter has a problem with the current TF release
|
# NOTE: linter has a problem with the current TF release
|
||||||
#pylint: disable=no-value-for-parameter
|
#pylint: disable=no-value-for-parameter
|
||||||
|
|
|
@ -37,7 +37,7 @@
|
||||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||||
"clip_norm": true, // clip normalized values into the range.
|
"clip_norm": true, // clip normalized values into the range.
|
||||||
"stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||||
},
|
},
|
||||||
|
|
||||||
// VOCABULARY PARAMETERS
|
// VOCABULARY PARAMETERS
|
||||||
|
|
|
@ -50,10 +50,44 @@ def test_decoder():
|
||||||
input_mask = torch.unsqueeze(
|
input_mask = torch.unsqueeze(
|
||||||
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
|
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
|
||||||
|
|
||||||
|
# residual bn conv decoder
|
||||||
layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
|
layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
|
||||||
output = layer(input_dummy, input_mask)
|
output = layer(input_dummy, input_mask)
|
||||||
assert list(output.shape) == [8, 11, 37]
|
assert list(output.shape) == [8, 11, 37]
|
||||||
|
|
||||||
|
# transformer decoder
|
||||||
|
layer = Decoder(out_channels=11,
|
||||||
|
in_hidden_channels=128,
|
||||||
|
decoder_type='transformer',
|
||||||
|
decoder_params={
|
||||||
|
'hidden_channels_ffn': 128,
|
||||||
|
'num_heads': 2,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"dropout_p": 0.1,
|
||||||
|
"num_layers": 8,
|
||||||
|
"rel_attn_window_size": 4,
|
||||||
|
"input_length": None
|
||||||
|
}).to(device)
|
||||||
|
output = layer(input_dummy, input_mask)
|
||||||
|
assert list(output.shape) == [8, 11, 37]
|
||||||
|
|
||||||
|
|
||||||
|
# wavenet decoder
|
||||||
|
layer = Decoder(out_channels=11,
|
||||||
|
in_hidden_channels=128,
|
||||||
|
decoder_type='wavenet',
|
||||||
|
decoder_params={
|
||||||
|
"num_blocks": 12,
|
||||||
|
"hidden_channels": 192,
|
||||||
|
"kernel_size": 5,
|
||||||
|
"dilation_rate": 1,
|
||||||
|
"num_layers": 4,
|
||||||
|
"dropout_p": 0.05
|
||||||
|
}).to(device)
|
||||||
|
output = layer(input_dummy, input_mask)
|
||||||
|
assert list(output.shape) == [8, 11, 37]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_duration_predictor():
|
def test_duration_predictor():
|
||||||
input_dummy = torch.rand(8, 128, 27).to(device)
|
input_dummy = torch.rand(8, 128, 27).to(device)
|
||||||
|
|
Loading…
Reference in New Issue