Add missing kernel size attr to transformer layer

Eren Gölge 2022-04-19 09:19:57 +02:00 committed by Eren Gölge
parent 231c69b12e
commit e7c5db0d97
1 changed file with 11 additions and 3 deletions

@@ -36,7 +36,7 @@ class FFTransformer(nn.Module):
 class FFTransformerBlock(nn.Module):
-    def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p):
+    def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p, kernel_size_fft):
         super().__init__()
         self.fft_layers = nn.ModuleList(
             [
@@ -45,6 +45,7 @@ class FFTransformerBlock(nn.Module):
                     num_heads=num_heads,
                     hidden_channels_ffn=hidden_channels_ffn,
                     dropout_p=dropout_p,
+                    kernel_size_fft=kernel_size_fft,
                 )
                 for _ in range(num_layers)
             ]
@@ -71,9 +72,16 @@ class FFTransformerBlock(nn.Module):
 class FFTDurationPredictor:
     def __init__(
-        self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
+        self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None, kernel_size_fft=3
     ):  # pylint: disable=unused-argument
-        self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
+        self.fft = FFTransformerBlock(
+            in_out_channels=in_channels,
+            num_heads=num_heads,
+            hidden_channels_ffn=hidden_channels,
+            num_layers=num_layers,
+            dropout_p=dropout_p,
+            kernel_size_fft=kernel_size_fft,
+        )
         self.proj = nn.Linear(in_channels, 1)

     def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
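
For reference, a minimal usage sketch of the updated constructors. The import path, tensor shapes, and example hyperparameters below are assumptions based on the Coqui TTS layout and are not part of this diff.

# Usage sketch only; import path and shapes are assumed, not taken from this diff.
import torch

from TTS.tts.layers.generic.transformer import FFTDurationPredictor, FFTransformerBlock

# kernel_size_fft is now a required argument of FFTransformerBlock and is
# forwarded to each FFTransformer layer it builds.
block = FFTransformerBlock(
    in_out_channels=80,
    num_heads=2,
    hidden_channels_ffn=256,
    num_layers=2,
    dropout_p=0.1,
    kernel_size_fft=3,
)
x = torch.randn(4, 80, 50)  # [batch, channels, time]
y = block(x)  # output keeps the input shape

# FFTDurationPredictor exposes the same knob, defaulting to kernel_size_fft=3.
predictor = FFTDurationPredictor(
    in_channels=80,
    hidden_channels=256,
    num_heads=2,
    num_layers=2,
    dropout_p=0.1,
    kernel_size_fft=3,
)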