From f2ef1ca36afd4eb74bd1ad36a7a84411f69e5435 Mon Sep 17 00:00:00 2001 From: Eren Date: Wed, 19 Sep 2018 15:08:43 +0200 Subject: [PATCH] Smoothed attention as in https://arxiv.org/pdf/1506.07503.pdf --- layers/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/layers/attention.py b/layers/attention.py index aa5c94ce..bbf83055 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -131,7 +131,8 @@ class AttentionRNNCell(nn.Module): mask = mask.view(memory.size(0), -1) alignment.masked_fill_(1 - mask, -float("inf")) # Normalize context weight - alignment = F.softmax(alignment, dim=-1) + # alignment = F.softmax(alignment, dim=-1) + alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1) # Attention context vector # (batch, 1, dim) # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j