mirror of https://github.com/coqui-ai/TTS.git
fix forward attention
parent 01dbfb3a0f
commit 9ba13b2d2f
@@ -152,7 +152,7 @@ class Attention(nn.Module):
         """
         B = inputs.shape[0]
         T = inputs.shape[1]
-        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
+        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
 
     def get_attention(self, query, processed_inputs, attention_cat):
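The change in this hunk seeds the non-initial positions of self.alpha with 1e-7 instead of exact zeros, so the multiplicative forward-attention update below can never collapse the whole distribution to zero before normalization. A minimal standalone sketch of the resulting initial state, assuming a hypothetical batch size B = 2 and input length T = 5 (illustration only, not part of the commit):

import torch

# Hypothetical shapes for illustration; in the model they come from the encoder outputs.
B, T = 2, 5
# Initial forward-attention weights after this commit: all mass on the first
# encoder step, plus a small epsilon elsewhere so later element-wise products
# with the alignment never produce an all-zero vector.
alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1)
print(alpha[0])  # roughly tensor([1.0000e+00, 1.0000e-07, 1.0000e-07, 1.0000e-07, 1.0000e-07])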
@@ -183,16 +183,16 @@ class Attention(nn.Module):
     def apply_forward_attention(self, inputs, alignment, processed_query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
-        alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha)) * alignment
+        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
-        context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
             ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
-        return context, alpha_norm, alignment
+        return context, self.alpha, alignment
 
     def forward(self, attention_hidden_state, inputs, processed_inputs,
                 attention_cat, mask):
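After this hunk, the method first forms un-normalized forward weights by mixing the previous weights with their one-step right shift (stay with probability 1 - u, advance with probability u), gates them by the current alignment, normalizes, and only then computes the context, so self.alpha always stores a proper distribution and the stray 1e-7 in the update itself is gone. A minimal standalone sketch of one such step, using a hypothetical forward_attention_step helper and made-up shapes (B batch, T encoder steps, D encoder dimension); this is an illustration under those assumptions, not the class API:

import torch
import torch.nn.functional as F

def forward_attention_step(alpha_prev, u, alignment, inputs):
    # alpha_prev: [B, T] previous normalized forward-attention weights
    # u:          [B, 1] transition probability of advancing one encoder step
    # alignment:  [B, T] current content-based attention weights
    # inputs:     [B, T, D] encoder outputs
    # Shift the previous weights one step to the right: position n inherits mass from n - 1.
    shifted = F.pad(alpha_prev[:, :-1], (1, 0, 0, 0))
    # Mix the "stay" and "advance" terms, gate by the alignment, then renormalize.
    alpha = ((1 - u) * alpha_prev + u * shifted) * alignment
    alpha = alpha / alpha.sum(dim=1, keepdim=True)
    # Context is the alpha-weighted sum of the encoder outputs.
    context = torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)
    return context, alpha

In the class itself, self.u is then refreshed by the transition agent, a sigmoid over the concatenated context and processed query, before the next decoder step.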