mirror of https://github.com/coqui-ai/TTS.git
FFTransformer Decoder for AlignTTS
This commit is contained in:
parent 2c364c0df8
commit 460a2d3e26
@@ -3,6 +3,7 @@ from torch import nn
 from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
 from TTS.tts.layers.generic.wavenet import WNBlocks
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.generic.transformer import FFTransformersBlock


 class WaveNetDecoder(nn.Module):
@@ -89,6 +90,36 @@ class RelativePositionTransformerDecoder(nn.Module):
         return o


+class FFTransformerDecoder(nn.Module):
+    """Decoder with FeedForwardTransformer.
+
+    Note:
+        Default params
+            params={
+                'hidden_channels_ffn': 1024,
+                'num_heads': 2,
+                'dropout_p': 0.1,
+                'num_layers': 6,
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels of the Transformer layers.
+        params (dict): dictionary of FFTransformer block parameters.
+    """
+    def __init__(self, in_channels, out_channels, params):
+        super().__init__()
+        self.transformer_block = FFTransformersBlock(in_channels, **params)
+        self.postnet = nn.Conv1d(in_channels, out_channels, 1)
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        # TODO: handle multi-speaker
+        o = self.transformer_block(x) * x_mask
+        o = self.postnet(o) * x_mask
+        return o
+
+
 class ResidualConv1dBNDecoder(nn.Module):
     """Residual Convolutional Decoder as in the original Speedy Speech paper
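For orientation (not part of the commit): a minimal usage sketch of the new decoder, assuming the (batch, channels, time) tensor layout used by the other decoders in this file, the default params from the docstring, and illustrative channel sizes. Note that forward multiplies by x_mask unconditionally despite its None default, so an explicit all-ones mask is needed even for unpadded input.

    import torch

    # Defaults taken from the docstring above.
    params = {
        "hidden_channels_ffn": 1024,
        "num_heads": 2,
        "dropout_p": 0.1,
        "num_layers": 6,
    }

    # Channel sizes are illustrative, not mandated by the commit.
    decoder = FFTransformerDecoder(in_channels=128, out_channels=80, params=params)

    x = torch.randn(2, 128, 37)    # hidden features: (batch, channels, time)
    x_mask = torch.ones(2, 1, 37)  # all-ones padding mask, broadcast over channels

    o = decoder(x, x_mask)         # -> (2, 80, 37), e.g. mel-spectrogram frames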
@@ -159,24 +190,26 @@ class Decoder(nn.Module):
                  c_in_channels=0):
         super().__init__()

-        if decoder_type == 'transformer':
+        if decoder_type.lower() == "relative_position_transformer":
             self.decoder = RelativePositionTransformerDecoder(
                 in_channels=in_hidden_channels,
                 out_channels=out_channels,
                 hidden_channels=in_hidden_channels,
                 params=decoder_params)
-        elif decoder_type == 'residual_conv_bn':
+        elif decoder_type.lower() == 'residual_conv_bn':
             self.decoder = ResidualConv1dBNDecoder(
                 in_channels=in_hidden_channels,
                 out_channels=out_channels,
                 hidden_channels=in_hidden_channels,
                 params=decoder_params)
-        elif decoder_type == 'wavenet':
+        elif decoder_type.lower() == 'wavenet':
             self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
                                           out_channels=out_channels,
                                           hidden_channels=in_hidden_channels,
                                           c_in_channels=c_in_channels,
                                           params=decoder_params)
+        elif decoder_type.lower() == 'transformer':
+            self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
         else:
             raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
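After this hunk, decoder_type 'transformer' selects the new FFTransformerDecoder, the previous behavior moves to 'relative_position_transformer', and the added .lower() calls make the match case-insensitive; configs that relied on the old meaning of 'transformer' need updating. A sketch of driving the factory (keyword names are taken from the diff above; channel sizes and the keyword-only call style are assumptions for illustration):

    decoder = Decoder(
        out_channels=80,             # e.g. number of mel channels
        in_hidden_channels=128,
        decoder_type="transformer",  # now maps to FFTransformerDecoder
        decoder_params={
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "dropout_p": 0.1,
            "num_layers": 6,
        },
    )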