diff --git a/layers/attention.py b/layers/attention.py
index aa5c94ce..bbf83055 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -131,7 +131,8 @@ class AttentionRNNCell(nn.Module):
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
-        alignment = F.softmax(alignment, dim=-1)
+        # alignment = F.softmax(alignment, dim=-1)
+        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
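For context, here is a minimal, self-contained sketch of the normalized-sigmoid weighting this hunk switches to, compared against the softmax it replaces. The function name, tensor shapes, and mask convention below are illustrative assumptions, not code from layers/attention.py:

```python
import torch
import torch.nn.functional as F


def sigmoid_norm_weights(alignment: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """Sigmoid-based attention weights (illustrative sketch).

    alignment: (batch, T) raw attention scores.
    mask:      optional (batch, T) bool tensor, True for valid positions (assumed convention).
    Returns weights that sum to 1 over the time axis, like softmax, but each
    score is squashed independently by a sigmoid before normalization.
    """
    if mask is not None:
        # Masked positions get -inf, so sigmoid(-inf) == 0 and they drop out.
        alignment = alignment.masked_fill(~mask, -float("inf"))
    scores = torch.sigmoid(alignment)
    return scores / scores.sum(dim=1, keepdim=True)


if __name__ == "__main__":
    scores = torch.randn(2, 5)
    print(sigmoid_norm_weights(scores))  # rows sum to 1
    print(F.softmax(scores, dim=-1))     # rows also sum to 1, typically sharper
```

Unlike softmax, which exponentiates scores and so tends to concentrate mass on the largest one, the sigmoid-based weights are bounded per position before normalization, giving a flatter distribution over the memory.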