diff --git a/layers/common_layers.py b/layers/common_layers.py
index 592f017c..8b7ed125 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -164,6 +164,9 @@ class GravesAttention(nn.Module):
         b_t = gbk_t[:, 1, :]
         k_t = gbk_t[:, 2, :]
 
+        # dropout to decorrelate attention heads
+        g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training)
+
         # attention GMM parameters
         sig_t = torch.nn.functional.softplus(b_t) + self.eps
 