diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 9b7ecee2..29f3c888 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -36,7 +36,7 @@ class FFTransformer(nn.Module): class FFTransformerBlock(nn.Module): - def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p, kernel_size_fft): super().__init__() self.fft_layers = nn.ModuleList( [ @@ -45,6 +45,7 @@ class FFTransformerBlock(nn.Module): num_heads=num_heads, hidden_channels_ffn=hidden_channels_ffn, dropout_p=dropout_p, + kernel_size_fft=kernel_size_fft, ) for _ in range(num_layers) ] @@ -71,9 +72,16 @@ class FFTransformerBlock(nn.Module): class FFTDurationPredictor: def __init__( - self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None, kernel_size_fft=3 ): # pylint: disable=unused-argument - self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.fft = FFTransformerBlock( + in_out_channels=in_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels, + num_layers=num_layers, + dropout_p=dropout_p, + kernel_size_fft=kernel_size_fft, + ) self.proj = nn.Linear(in_channels, 1) def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument