mirror of https://github.com/coqui-ai/TTS.git

select decoder type for SS

This commit is contained in:
parent d5a0190c4b
commit 4ef083f0f1
@@ -84,7 +84,7 @@ class WN(torch.nn.Module):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(dropout_p)
 
-        if c_in_channels != 0:
+        if c_in_channels > 0:
             cond_layer = torch.nn.Conv1d(c_in_channels,
                                          2 * hidden_channels * num_layers, 1)
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
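The functional change in the hunk above is the guard on the conditioning layer: "if c_in_channels != 0" becomes "if c_in_channels > 0", so the weight-normalized conditioning convolution is only built when a positive number of conditioning channels is configured. A minimal, self-contained sketch of that pattern follows; ConditionedBlock and its parameter names are illustrative stand-ins, not the actual WN module.

import torch
from torch import nn


class ConditionedBlock(nn.Module):
    # Illustrative stand-in for WN's conditioning setup: the 1x1
    # conditioning conv is only registered when c_in_channels > 0.
    def __init__(self, hidden_channels, num_layers, c_in_channels=0):
        super().__init__()
        self.c_in_channels = c_in_channels
        if c_in_channels > 0:
            cond_layer = nn.Conv1d(c_in_channels,
                                   2 * hidden_channels * num_layers, 1)
            self.cond_layer = nn.utils.weight_norm(cond_layer, name='weight')

    def forward(self, x, g=None):
        # Project the conditioning signal only when conditioning is configured.
        if self.c_in_channels > 0 and g is not None:
            g = self.cond_layer(g)
        return x, g


# Unconditional model: no cond_layer parameters are created at all.
block = ConditionedBlock(hidden_channels=192, num_layers=4)
x = torch.randn(2, 192, 50)
_, g_out = block(x)
print(g_out)  # None

# Conditional model, e.g. a 256-channel speaker embedding.
block_c = ConditionedBlock(hidden_channels=192, num_layers=4, c_in_channels=256)
g = torch.randn(2, 256, 1)
_, g_out = block_c(x, g)
print(g_out.shape)  # torch.Size([2, 1536, 1])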
@@ -1,32 +1,72 @@
 from torch import nn
 
 from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
+from TTS.tts.layers.glow_tts.transformer import Transformer
 
 
 class Decoder(nn.Module):
     """Decodes the expanded phoneme encoding into spectrograms
     Args:
         out_channels (int): number of output channels.
+        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
+        decoder_type (str): decoder layer type. 'transformer' or 'residual_conv_bn'. Default 'residual_conv_bn'.
+        decoder_params (dict): model parameters for the specified decoder type.
+        c_in_channels (int): number of channels for conditional input.
 
     Shapes:
         - input: (B, C, T)
+
+    Note:
+        Default decoder_params...
+
+        for 'transformer'
+            decoder_params={
+                'hidden_channels_ffn': 128,
+                'num_heads': 2,
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 8,
+                "rel_attn_window_size": 4,
+                "input_length": None
+            },
+
+        for 'residual_conv_bn'
+            decoder_params = {
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            }
     """
+    # pylint: disable=dangerous-default-value
     def __init__(
             self,
             out_channels,
-            hidden_channels,
-            residual_conv_bn_params={
+            in_hidden_channels,
+            decoder_type='residual_conv_bn',
+            decoder_params={
                 "kernel_size": 4,
                 "dilations": 4 * [1, 2, 4, 8] + [1],
                 "num_conv_blocks": 2,
                 "num_res_blocks": 17
-            }):
+            },
+            c_in_channels=0):
         super().__init__()
+        self.in_channels = in_hidden_channels
+        self.hidden_channels = in_hidden_channels
+        self.out_channels = out_channels
+
-        self.decoder = ResidualConvBNBlock(hidden_channels,
-                                           **residual_conv_bn_params)
+        if decoder_type == 'transformer':
+            self.decoder = Transformer(self.hidden_channels, **decoder_params)
+        elif decoder_type == 'residual_conv_bn':
+            self.decoder = ResidualConvBNBlock(self.hidden_channels,
+                                               **decoder_params)
+        else:
+            raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
 
-        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
+        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
         self.post_net = nn.Sequential(
-            ConvBNBlock(hidden_channels, residual_conv_bn_params['kernel_size'], 1, num_conv_blocks=2),
-            nn.Conv1d(hidden_channels, out_channels, 1),
+            ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
+            nn.Conv1d(self.hidden_channels, out_channels, 1),
         )
 
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
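A quick usage sketch for the Decoder after this change. The import path (TTS.tts.layers.speedy_speech.decoder), the tensor shapes, and the expectation that forward returns (B, out_channels, T) frames are assumptions for illustration, not part of the commit; the transformer decoder_params are the defaults quoted in the docstring above.

import torch
from TTS.tts.layers.speedy_speech.decoder import Decoder  # assumed module path

B, C, T = 2, 128, 57             # batch, in_hidden_channels, expanded frames
x = torch.randn(B, C, T)         # expanded phoneme encoding, (B, C, T)
x_mask = torch.ones(B, 1, T)     # all frames valid

# Default decoder: residual conv-bn stack.
decoder = Decoder(out_channels=80, in_hidden_channels=C)
o = decoder(x, x_mask)
print(o.shape)  # expected (B, 80, T)

# Transformer decoder, selected via decoder_type.
decoder_tf = Decoder(out_channels=80,
                     in_hidden_channels=C,
                     decoder_type='transformer',
                     decoder_params={
                         'hidden_channels_ffn': 128,
                         'num_heads': 2,
                         'kernel_size': 3,
                         'dropout_p': 0.1,
                         'num_layers': 8,
                         'rel_attn_window_size': 4,
                         'input_length': None
                     })
o_tf = decoder_tf(x, x_mask)
print(o_tf.shape)  # expected (B, 80, T)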