mirror of https://github.com/coqui-ai/TTS.git
fix forward attention
This commit is contained in:
parent
01dbfb3a0f
commit
9ba13b2d2f
|
@ -152,7 +152,7 @@ class Attention(nn.Module):
|
||||||
"""
|
"""
|
||||||
B = inputs.shape[0]
|
B = inputs.shape[0]
|
||||||
T = inputs.shape[1]
|
T = inputs.shape[1]
|
||||||
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
|
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
|
||||||
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
|
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
|
||||||
|
|
||||||
def get_attention(self, query, processed_inputs, attention_cat):
|
def get_attention(self, query, processed_inputs, attention_cat):
|
||||||
|
@ -183,16 +183,16 @@ class Attention(nn.Module):
|
||||||
def apply_forward_attention(self, inputs, alignment, processed_query):
|
def apply_forward_attention(self, inputs, alignment, processed_query):
|
||||||
# forward attention
|
# forward attention
|
||||||
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
|
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
|
||||||
self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
|
alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha)) * alignment
|
||||||
alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
|
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
||||||
# compute context
|
# compute context
|
||||||
context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
|
context = torch.bmm(self.alpha.unsqueeze(1), inputs)
|
||||||
context = context.squeeze(1)
|
context = context.squeeze(1)
|
||||||
# compute transition agent
|
# compute transition agent
|
||||||
if self.trans_agent:
|
if self.trans_agent:
|
||||||
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
|
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
|
||||||
self.u = torch.sigmoid(self.ta(ta_input))
|
self.u = torch.sigmoid(self.ta(ta_input))
|
||||||
return context, alpha_norm, alignment
|
return context, self.alpha, alignment
|
||||||
|
|
||||||
def forward(self, attention_hidden_state, inputs, processed_inputs,
|
def forward(self, attention_hidden_state, inputs, processed_inputs,
|
||||||
attention_cat, mask):
|
attention_cat, mask):
|
||||||
|
|
Loading…
Reference in New Issue