diff --git a/layers/common_layers.py b/layers/common_layers.py index 66ffcd1c..52be8bfd 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -127,8 +127,8 @@ class GravesAttention(nn.Module): self.init_layers() def init_layers(self): - torch.nn.init.constant_(self.N_a[2].bias[10:15], 0.5) - torch.nn.init.constant_(self.N_a[2].bias[5:10], 10) + torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) + torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) def init_states(self, inputs): if self.J is None or inputs.shape[1] > self.J.shape[-1]: