mirror of https://github.com/coqui-ai/TTS.git
commit
7292d303b9
|
@ -138,7 +138,7 @@ class GravesAttention(nn.Module):
|
|||
|
||||
def init_states(self, inputs):
|
||||
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
|
||||
self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
|
||||
self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
|
||||
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
||||
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
||||
|
||||
|
|
Loading…
Reference in New Issue