mirror of https://github.com/coqui-ai/TTS.git
Add missing kernel size attr to transformer layer
This commit is contained in:
parent
231c69b12e
commit
e7c5db0d97
TTS/tts/layers/generic
|
@ -36,7 +36,7 @@ class FFTransformer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FFTransformerBlock(nn.Module):
|
class FFTransformerBlock(nn.Module):
|
||||||
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p):
|
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p, kernel_size_fft):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fft_layers = nn.ModuleList(
|
self.fft_layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -45,6 +45,7 @@ class FFTransformerBlock(nn.Module):
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
hidden_channels_ffn=hidden_channels_ffn,
|
hidden_channels_ffn=hidden_channels_ffn,
|
||||||
dropout_p=dropout_p,
|
dropout_p=dropout_p,
|
||||||
|
kernel_size_fft=kernel_size_fft,
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
|
@ -71,9 +72,16 @@ class FFTransformerBlock(nn.Module):
|
||||||
|
|
||||||
class FFTDurationPredictor:
|
class FFTDurationPredictor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
|
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None, kernel_size_fft=3
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
self.fft = FFTransformerBlock(
|
||||||
|
in_out_channels=in_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
hidden_channels=hidden_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
kernel_size_fft=kernel_size_fft,
|
||||||
|
)
|
||||||
self.proj = nn.Linear(in_channels, 1)
|
self.proj = nn.Linear(in_channels, 1)
|
||||||
|
|
||||||
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
|
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
|
Loading…
Reference in New Issue