diff --git a/layers/attention.py b/layers/attention.py index 31cd03b6..8b445da8 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -34,7 +34,7 @@ class LocationSensitiveAttention(nn.Module): """Location sensitive attention following https://arxiv.org/pdf/1506.07503.pdf""" def __init__(self, annot_dim, query_dim, attn_dim, - kernel_size=31, filters=32): + kernel_size=7, filters=20): super(LocationSensitiveAttention, self).__init__() self.kernel_size = kernel_size self.filters = filters