mirror of https://github.com/coqui-ai/TTS.git
FFTransformer Decoder for AlignTTS
This commit is contained in:
parent
2c364c0df8
commit
460a2d3e26
|
@ -3,6 +3,7 @@ from torch import nn
|
||||||
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
|
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
|
||||||
from TTS.tts.layers.generic.wavenet import WNBlocks
|
from TTS.tts.layers.generic.wavenet import WNBlocks
|
||||||
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||||
|
from TTS.tts.layers.generic.transformer import FFTransformersBlock
|
||||||
|
|
||||||
|
|
||||||
class WaveNetDecoder(nn.Module):
|
class WaveNetDecoder(nn.Module):
|
||||||
|
@ -89,6 +90,36 @@ class RelativePositionTransformerDecoder(nn.Module):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class FFTransformerDecoder(nn.Module):
    """Decoder with FeedForwardTransformer.

    The input is passed through an ``FFTransformersBlock`` and the result is
    projected to ``out_channels`` by a kernel-size-1 ``nn.Conv1d`` postnet.

    Note:
        Default params
                params={
                    'hidden_channels_ffn': 1024,
                    'num_heads': 2,
                    "dropout_p": 0.1,
                    "num_layers": 6,
                }

    Args:
        in_channels (int): number of input channels. It is also the channel
            size handed to the transformer block and the postnet input.
        out_channels (int): number of output channels.
        params (dict): keyword arguments forwarded to ``FFTransformersBlock``.
    """

    def __init__(self, in_channels, out_channels, params):
        super().__init__()
        self.transformer_block = FFTransformersBlock(in_channels, **params)
        # 1x1 convolution that maps the transformer channels to the output size.
        self.postnet = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the transformer block and the postnet, masking after each step.

        Args:
            x: input tensor; presumably shaped [B, C, T] as required by the
                ``nn.Conv1d`` postnet - TODO confirm against the caller.
            x_mask: optional mask multiplied into the output of both stages.
                When ``None`` no masking is applied.
            g: unused; kept so the signature matches the other decoders.

        Returns:
            The postnet output with ``out_channels`` channels.
        """
        # TODO: handle multi-speaker
        # The signature advertises ``x_mask=None``; fall back to a neutral
        # multiplier so the default call does not fail on ``Tensor * None``.
        if x_mask is None:
            x_mask = 1
        o = self.transformer_block(x) * x_mask
        o = self.postnet(o) * x_mask
        return o
|
||||||
|
|
||||||
class ResidualConv1dBNDecoder(nn.Module):
|
class ResidualConv1dBNDecoder(nn.Module):
|
||||||
"""Residual Convolutional Decoder as in the original Speedy Speech paper
|
"""Residual Convolutional Decoder as in the original Speedy Speech paper
|
||||||
|
|
||||||
|
@ -159,24 +190,26 @@ class Decoder(nn.Module):
|
||||||
c_in_channels=0):
|
c_in_channels=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if decoder_type == 'transformer':
|
if decoder_type.lower() == "relative_position_transformer":
|
||||||
self.decoder = RelativePositionTransformerDecoder(
|
self.decoder = RelativePositionTransformerDecoder(
|
||||||
in_channels=in_hidden_channels,
|
in_channels=in_hidden_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
hidden_channels=in_hidden_channels,
|
hidden_channels=in_hidden_channels,
|
||||||
params=decoder_params)
|
params=decoder_params)
|
||||||
elif decoder_type == 'residual_conv_bn':
|
elif decoder_type.lower() == 'residual_conv_bn':
|
||||||
self.decoder = ResidualConv1dBNDecoder(
|
self.decoder = ResidualConv1dBNDecoder(
|
||||||
in_channels=in_hidden_channels,
|
in_channels=in_hidden_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
hidden_channels=in_hidden_channels,
|
hidden_channels=in_hidden_channels,
|
||||||
params=decoder_params)
|
params=decoder_params)
|
||||||
elif decoder_type == 'wavenet':
|
elif decoder_type.lower() == 'wavenet':
|
||||||
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
|
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
hidden_channels=in_hidden_channels,
|
hidden_channels=in_hidden_channels,
|
||||||
c_in_channels=c_in_channels,
|
c_in_channels=c_in_channels,
|
||||||
params=decoder_params)
|
params=decoder_params)
|
||||||
|
elif decoder_type.lower() == 'transformer':
|
||||||
|
self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
|
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue