From 67df3852754cd167ea0b726ed103d72e82f5ad62 Mon Sep 17 00:00:00 2001
From: Eren
Date: Wed, 19 Sep 2018 14:20:02 +0200
Subject: [PATCH] Explicit padding for unbalanced padding sizes

---
 layers/attention.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index 082b7127..aa5c94ce 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -43,13 +43,15 @@ class LocationSensitiveAttention(nn.Module):
         self.kernel_size = kernel_size
         self.filters = filters
         padding = [(kernel_size - 1) // 2, (kernel_size - 1) // 2]
-        self.loc_conv = nn.Conv1d(
-            2,
-            filters,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=padding,
-            bias=False)
+        self.loc_conv = nn.Sequential(
+            nn.ConstantPad1d(padding, 0),
+            nn.Conv1d(
+                2,
+                filters,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=0,
+                bias=False))
         self.loc_linear = nn.Linear(filters, attn_dim)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@@ -100,8 +102,8 @@ class AttentionRNNCell(nn.Module):
                 annot_dim, rnn_dim, out_dim)
         else:
             raise RuntimeError(" Wrong alignment model name: {}. Use\
-                'b' (Bahdanau) or 'ls' (Location Sensitive)."
-                               .format(align_model))
+                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
+                    align_model))
 
     def forward(self, memory, context, rnn_state, annots, atten, mask):
         """
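
Note (not part of the patch): a minimal sketch of why the explicit pad layer matters. nn.Conv1d applies its padding argument symmetrically to both ends of the sequence, while nn.ConstantPad1d takes a separate (left, right) pair, so wrapping the conv in nn.Sequential can keep the output length equal to the input length even when the two sides must differ, e.g. for an even kernel_size. The kernel size, channel counts, and left/right split below are illustrative values, not taken from the patch.

import torch
import torch.nn as nn

# Illustrative values (not from the patch): an even kernel needs an
# unbalanced (left, right) pad to keep the sequence length unchanged.
kernel_size = 4
padding = [(kernel_size - 1) // 2, kernel_size // 2]  # left=1, right=2

loc_conv = nn.Sequential(
    nn.ConstantPad1d(padding, 0),  # pad the two ends independently
    nn.Conv1d(2, 32, kernel_size=kernel_size, stride=1, padding=0,
              bias=False))

x = torch.randn(8, 2, 50)          # (batch, channels, time)
print(loc_conv(x).shape)           # torch.Size([8, 32, 50]), length preserved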