mirror of https://github.com/coqui-ai/TTS.git
bug fix for illegal memory reach
This commit is contained in:
parent
b9e0faca98
commit
695bf1a1f6
|
@ -127,7 +127,7 @@ class GravesAttention(nn.Module):
|
||||||
|
|
||||||
def init_states(self, inputs):
|
def init_states(self, inputs):
|
||||||
if self.J is None or inputs.shape[1] > self.J.shape[-1]:
|
if self.J is None or inputs.shape[1] > self.J.shape[-1]:
|
||||||
self.J = torch.arange(0, inputs.shape[1]).expand_as(torch.Tensor(inputs.shape[0], self.K, inputs.shape[1])).to(inputs.device)
|
self.J = torch.arange(0, inputs.shape[1]).to(inputs.device).expand([inputs.shape[0], self.K, inputs.shape[1]])
|
||||||
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
||||||
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue