mirror of https://github.com/coqui-ai/TTS.git
FFTransformer encoder for aligntts
commit c2d29e5cd4
parent 460a2d3e26
@@ -4,59 +4,7 @@ from torch import nn
 
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
+from TTS.tts.layers.generic.transformer import FFTransformersBlock
 
 
-class PositionalEncoding(nn.Module):
-    """Sinusoidal positional encoding for non-recurrent neural networks.
-    Implementation based on "Attention Is All You Need"
-    Args:
-       channels (int): embedding size
-       dropout (float): dropout parameter
-    """
-    def __init__(self, channels, dropout=0.0, max_len=5000):
-        super().__init__()
-        if channels % 2 != 0:
-            raise ValueError(
-                "Cannot use sin/cos positional encoding with "
-                "odd channels (got channels={:d})".format(channels))
-        pe = torch.zeros(max_len, channels)
-        position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
-                              -(math.log(10000.0) / channels)))
-        pe[:, 0::2] = torch.sin(position.float() * div_term)
-        pe[:, 1::2] = torch.cos(position.float() * div_term)
-        pe = pe.unsqueeze(0).transpose(1, 2)
-        self.register_buffer('pe', pe)
-        if dropout > 0:
-            self.dropout = nn.Dropout(p=dropout)
-        self.channels = channels
-
-    def forward(self, x, mask=None, first_idx=None, last_idx=None):
-        """
-        Shapes:
-            x: [B, C, T]
-            mask: [B, 1, T]
-            first_idx: int
-            last_idx: int
-        """
-
-        x = x * math.sqrt(self.channels)
-        if first_idx is None:
-            if self.pe.size(2) < x.size(2):
-                raise RuntimeError(
-                    f"Sequence is {x.size(2)} but PositionalEncoding is"
-                    f" limited to {self.pe.size(2)}. See max_len argument.")
-            if mask is not None:
-                pos_enc = (self.pe[:, :, :x.size(2)] * mask)
-            else:
-                pos_enc = self.pe[:, :, :x.size(2)]
-            x = x + pos_enc
-        else:
-            x = x + self.pe[:, :, first_idx:last_idx]
-        if hasattr(self, 'dropout'):
-            x = self.dropout(x)
-        return x
-
-
 class RelativePositionTransformerEncoder(nn.Module):
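For reference, the PositionalEncoding class removed above implements the standard sinusoidal encoding from "Attention Is All You Need": a fixed sin/cos table of shape [1, C, max_len] that is sliced to the input length and added to the scaled input in [B, C, T] layout. A minimal sketch of the same computation follows; sinusoidal_table is an illustrative helper name, not part of the TTS API.

import math
import torch

def sinusoidal_table(channels, max_len=5000):
    # [max_len, channels] table: sin on even channels, cos on odd channels
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, channels, 2, dtype=torch.float)
                         * -(math.log(10000.0) / channels))
    pe = torch.zeros(max_len, channels)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0).transpose(1, 2)  # -> [1, channels, max_len]

x = torch.randn(2, 128, 50)                    # [B, C, T]
pe = sinusoidal_table(128)
y = x * math.sqrt(128) + pe[:, :, :x.size(2)]  # same sqrt(C) scaling as the removed class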
@@ -138,9 +86,9 @@ class Encoder(nn.Module):
         c_in_channels (int): number of channels for conditional input.
 
     Note:
-        Default encoder_params...
+        Default encoder_params to be set in config.json...
 
-            for 'transformer'
+            for 'relative_position_transformer'
                 encoder_params={
                     'hidden_channels_ffn': 128,
                     'num_heads': 2,
@@ -158,6 +106,14 @@ class Encoder(nn.Module):
                     "num_conv_blocks": 2,
                     "num_res_blocks": 13
                 }
+
+            for 'transformer_decoder'
+                encoder_params = {
+                    hidden_channels_ffn: 1024 ,
+                    num_heads: 2,
+                    num_layers: 6,
+                    dropout_p: 0.1
+                }
     """
     def __init__(
             self,
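Taking the docstring defaults above at face value, configuring the new feed-forward Transformer path could look roughly like the snippet below. Only the dict contents come from the docstring; the surrounding variable names and the commented constructor call are assumptions made for illustration, not the project's config schema.

# Hypothetical configuration mirroring the docstring defaults above; everything
# outside the dict is an assumed usage sketch.
encoder_type = "transformer"        # dispatches to FFTransformersBlock (see the hunks below)
encoder_params = {
    "hidden_channels_ffn": 1024,
    "num_heads": 2,
    "num_layers": 6,
    "dropout_p": 0.1,
}
# encoder = Encoder(in_hidden_channels=128, out_channels=80,
#                   encoder_type=encoder_type, encoder_params=encoder_params)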
@@ -179,7 +135,7 @@ class Encoder(nn.Module):
         self.c_in_channels = c_in_channels
 
         # init encoder
-        if encoder_type.lower() == "transformer":
+        if encoder_type.lower() == "relative_position_transformer":
             # text encoder
             self.encoder = RelativePositionTransformerEncoder(
                 in_hidden_channels, out_channels, in_hidden_channels,
@@ -189,11 +145,11 @@ class Encoder(nn.Module):
                 out_channels,
                 in_hidden_channels,
                 encoder_params)
+        elif encoder_type.lower() == 'transformer':
+            self.encoder = FFTransformersBlock(in_hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
         else:
             raise NotImplementedError(' [!] unknown encoder type.')
 
-        # final projection layers
-
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         """
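The new 'transformer' branch builds the encoder from FFTransformersBlock, with the documented parameters unpacked as keyword arguments. For orientation, the sketch below shows the kind of layer a feed-forward Transformer (FFT) encoder typically stacks: multi-head self-attention followed by a 1D-convolutional feed-forward network, each with a residual connection and layer normalization. It is a rough, self-contained approximation under those assumptions, not the actual TTS.tts.layers.generic.transformer implementation; masking, norm placement and kernel sizes may differ.

import torch
from torch import nn

class FFTBlockSketch(nn.Module):
    # Approximation of one FFT layer; FFTransformersBlock presumably stacks
    # num_layers of something similar.
    def __init__(self, channels, num_heads=2, hidden_channels_ffn=1024, dropout_p=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, dropout=dropout_p)
        self.norm1 = nn.LayerNorm(channels)
        self.conv1 = nn.Conv1d(channels, hidden_channels_ffn, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_channels_ffn, channels, kernel_size=3, padding=1)
        self.norm2 = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):                                  # x: [B, C, T]
        y = x.permute(2, 0, 1)                             # [T, B, C] for nn.MultiheadAttention
        y, _ = self.attn(y, y, y)
        x = x + self.dropout(y).permute(1, 2, 0)           # residual, back to [B, C, T]
        x = self.norm1(x.transpose(1, 2)).transpose(1, 2)  # LayerNorm over channels
        y = self.conv2(torch.relu(self.conv1(x)))          # conv feed-forward network
        x = x + self.dropout(y)                            # second residual
        return self.norm2(x.transpose(1, 2)).transpose(1, 2)

x = torch.randn(2, 128, 50)              # [B, C, T], e.g. character embeddings
print(FFTBlockSketch(128)(x).shape)      # torch.Size([2, 128, 50])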