dropout Graves attention heads to decorrelate them and prevent a single head from overpowering the others

This commit is contained in:
erogol 2020-03-10 13:53:04 +01:00
parent 975842f71a
commit 201f04d3b3
1 changed file with 3 additions and 0 deletions


@@ -164,6 +164,9 @@ class GravesAttention(nn.Module):
b_t = gbk_t[:, 1, :]
k_t = gbk_t[:, 2, :]
# dropout to decorrelate attention heads
g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training)
# attention GMM parameters
sig_t = torch.nn.functional.softplus(b_t) + self.eps
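
For context, here is a minimal sketch of where the added dropout sits in a Graves-style GMM attention step. Only the lines shown in the diff above come from the repo; the function signature, tensor shapes, and the Gaussian form of the alignment are illustrative assumptions, not the repository's exact implementation.

import torch

def graves_attention_step(gbk_t, mu_prev, J, eps=1e-5, training=True):
    """Sketch of one Graves GMM attention step.
    gbk_t:   [B, 3, K] packed head parameters (g, b, k)
    mu_prev: [B, K]    previous Gaussian means
    J:       [T]       encoder time indices
    Returns (alignment [B, T], mu_t [B, K])."""
    g_t = gbk_t[:, 0, :]   # mixture-weight logits
    b_t = gbk_t[:, 1, :]   # scale logits
    k_t = gbk_t[:, 2, :]   # mean-offset logits

    # the change in this commit: dropout on the mixture-weight logits
    # randomly zeroes per-head logits during training to decorrelate heads
    g_t = torch.nn.functional.dropout(g_t, p=0.5, training=training)

    # attention GMM parameters
    sig_t = torch.nn.functional.softplus(b_t) + eps
    mu_t = mu_prev + torch.nn.functional.softplus(k_t)
    g_t = torch.softmax(g_t, dim=-1) + eps

    # evaluate K Gaussians over encoder positions and sum the mixture
    j = J[None, None, :]                                   # [1, 1, T]
    phi_t = g_t.unsqueeze(-1) * torch.exp(
        -0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / sig_t.unsqueeze(-1) ** 2
    )                                                      # [B, K, T]
    alignment = phi_t.sum(1)                               # [B, T]
    return alignment, mu_t

During training, dropout zeroes each head's mixture-weight logit with probability 0.5 (rescaling the kept ones), which is the commit's stated mechanism for decorrelating the heads so that no single head dominates the alignment; at inference, training=False makes the call a no-op.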