FFTransformer Decoder for AlignTTS

This commit is contained in:
Eren Gölge 2021-03-16 17:05:15 +01:00
parent 2c364c0df8
commit 460a2d3e26
1 changed file with 36 additions and 3 deletions

View File

@ -3,6 +3,7 @@ from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.transformer import FFTransformersBlock
class WaveNetDecoder(nn.Module):
@ -89,6 +90,36 @@ class RelativePositionTransformerDecoder(nn.Module):
return o
class FFTransformerDecoder(nn.Module):
    """Decoder made of a FeedForwardTransformer block followed by a 1x1 convolution.

    The input is passed through ``FFTransformersBlock`` and then projected from
    ``in_channels`` to ``out_channels`` by a kernel-size-1 ``nn.Conv1d``. The mask,
    when given, is applied after each of the two stages.

    Note:
        Default params

        params={
            'hidden_channels_ffn': 1024,
            'num_heads': 2,
            "dropout_p": 0.1,
            "num_layers": 6,
        }

    Args:
        in_channels (int): number of input channels. Also used as the input
            channel size of the output projection.
        out_channels (int): number of output channels.
        params (dict): keyword arguments forwarded to ``FFTransformersBlock``.
    """

    def __init__(self, in_channels, out_channels, params):
        super().__init__()
        self.transformer_block = FFTransformersBlock(in_channels, **params)
        # Kernel size 1: a per-frame linear projection to the output channels.
        self.postnet = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the transformer block and the output projection.

        Args:
            x (Tensor): input features. The ``nn.Conv1d`` postnet expects its input
                as (batch, in_channels, time); presumably ``x`` has that layout
                too - TODO confirm against ``FFTransformersBlock``.
            x_mask (Tensor, optional): multiplicative mask that must broadcast
                against the block outputs. When ``None``, no masking is applied.
            g: unused; placeholder for speaker conditioning.

        Returns:
            Tensor: masked output of the postnet.
        """
        # TODO: handle multi-speaker
        # A missing mask means "keep everything"; multiplying by the scalar 1 is a
        # no-op and avoids a TypeError from `Tensor * None`.
        if x_mask is None:
            x_mask = 1
        o = self.transformer_block(x) * x_mask
        # Re-apply the mask: the conv bias makes masked frames non-zero again.
        o = self.postnet(o) * x_mask
        return o
class ResidualConv1dBNDecoder(nn.Module):
"""Residual Convolutional Decoder as in the original Speedy Speech paper
@ -159,24 +190,26 @@ class Decoder(nn.Module):
c_in_channels=0):
super().__init__()
if decoder_type == 'transformer':
if decoder_type.lower() == "relative_position_transformer":
self.decoder = RelativePositionTransformerDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type == 'residual_conv_bn':
elif decoder_type.lower() == 'residual_conv_bn':
self.decoder = ResidualConv1dBNDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type == 'wavenet':
elif decoder_type.lower() == 'wavenet':
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels,
params=decoder_params)
elif decoder_type.lower() == 'transformer':
self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')