diff --git a/layers/common_layers.py b/layers/common_layers.py index 8ea54f0e..6ddd0b6b 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -127,7 +127,7 @@ class GravesAttention(nn.Module): def init_states(self, inputs): if self.J is None or inputs.shape[1] > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]).expand_as(torch.Tensor(inputs.shape[0], self.K, inputs.shape[1])).to(inputs.device) + self.J = torch.arange(0, inputs.shape[1]).to(inputs.device).expand([inputs.shape[0], self.K, inputs.shape[1]]) self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)