diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 6d32c914..eeccbe14 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -3,6 +3,7 @@ from torch import nn from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN from TTS.tts.layers.generic.wavenet import WNBlocks from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.layers.generic.transformer import FFTransformersBlock class WaveNetDecoder(nn.Module): @@ -89,6 +90,36 @@ class RelativePositionTransformerDecoder(nn.Module): return o +class FFTransformerDecoder(nn.Module): + """Decoder with FeedForwardTransformer. + + Note: + Default params + params={ + 'hidden_channels_ffn': 1024, + 'num_heads': 2, + "dropout_p": 0.1, + "num_layers": 6, + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + def __init__(self, in_channels, out_channels, params): + + super().__init__() + self.transformer_block = FFTransformersBlock(in_channels, **params) + self.postnet = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + # TODO: handle multi-speaker + o = self.transformer_block(x) * x_mask + o = self.postnet(o)* x_mask + return o + class ResidualConv1dBNDecoder(nn.Module): """Residual Convolutional Decoder as in the original Speedy Speech paper @@ -159,24 +190,26 @@ class Decoder(nn.Module): c_in_channels=0): super().__init__() - if decoder_type == 'transformer': + if decoder_type.lower() == "relative_position_transformer": self.decoder = RelativePositionTransformerDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, params=decoder_params) - elif decoder_type == 'residual_conv_bn': + elif decoder_type.lower() == 'residual_conv_bn': self.decoder = ResidualConv1dBNDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, params=decoder_params) - elif decoder_type == 'wavenet': + elif decoder_type.lower() == 'wavenet': self.decoder = WaveNetDecoder(in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, c_in_channels=c_in_channels, params=decoder_params) + elif decoder_type.lower() == 'transformer': + self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) else: raise ValueError(f'[!] Unknown decoder type - {decoder_type}')