diff --git a/layers/common_layers.py b/layers/common_layers.py index 8b7ed125..78fa8b1c 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -138,7 +138,7 @@ class GravesAttention(nn.Module): def init_states(self, inputs): if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5 + self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5 self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)