mirror of https://github.com/coqui-ai/TTS.git
Fix trans agent implementation to match the paper: use the query vector instead of processed_query
This commit is contained in:
parent 5e679f746d
commit 2b60f9a731
@@ -131,7 +131,7 @@ class Attention(nn.Module):
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
+            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(attention_location_n_filters,
                                                 attention_location_kernel_size,
@@ -208,7 +208,7 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention

-    def apply_forward_attention(self, inputs, alignment, processed_query):
+    def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
         alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
@@ -218,7 +218,7 @@ class Attention(nn.Module):
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
-            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha

@@ -248,7 +248,7 @@ class Attention(nn.Module):
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, processed_query)
+            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
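For context, below is a minimal sketch of the transition-agent wiring this commit arrives at, following the forward-attention-with-transition-agent formulation the message refers to. The class name ForwardAttentionSketch, the init_states helper, and the toy dimensions are illustrative assumptions, not code from this repository; the real Attention module also handles windowing, location attention, and state setup.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ForwardAttentionSketch(nn.Module):
    """Illustrative stand-in for the fixed wiring; not the repository class."""

    def __init__(self, attention_rnn_dim, embedding_dim):
        super().__init__()
        # After the fix, the transition agent consumes the raw attention-RNN
        # state (width attention_rnn_dim) plus the context vector (width
        # embedding_dim), so the layer is sized accordingly.
        self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
        self.alpha = None  # forward attention weights, shape (B, T)
        self.u = None      # transition probability, shape (B, 1)

    def init_states(self, inputs):
        B, T, _ = inputs.shape
        # alpha starts as a one-hot on the first encoder step; u starts at 0.5
        self.alpha = torch.cat(
            [torch.ones(B, 1), torch.zeros(B, T - 1)], dim=1).to(inputs.device)
        self.u = torch.full((B, 1), 0.5, device=inputs.device)

    def apply_forward_attention(self, inputs, alignment, query):
        # alpha_{t-1} shifted one encoder step to the right
        prev_alpha = F.pad(self.alpha[:, :-1], (1, 0))
        # forward recursion: ((1-u)*alpha_{t-1} + u*shifted alpha_{t-1} + eps) * e_t
        alpha = ((1 - self.u) * self.alpha + self.u * prev_alpha + 1e-8) * alignment
        self.alpha = alpha / alpha.sum(dim=1, keepdim=True)
        context = torch.bmm(self.alpha.unsqueeze(1), inputs).squeeze(1)
        # transition agent input: the raw query, not the attention_dim projection
        ta_input = torch.cat([context, query], dim=-1)
        self.u = torch.sigmoid(self.ta(ta_input))
        return context, self.alpha

# toy shapes for a smoke test
B, T, D_emb, D_rnn = 2, 50, 512, 1024
attn = ForwardAttentionSketch(D_rnn, D_emb)
inputs = torch.randn(B, T, D_emb)                    # encoder outputs
attn.init_states(inputs)
alignment = torch.softmax(torch.randn(B, T), dim=1)  # stand-in attention energies
query = torch.randn(B, D_rnn)                        # attention RNN hidden state
context, alpha = attn.apply_forward_attention(inputs, alignment, query)

The sizing is why the diff also changes attention_dim to attention_rnn_dim above: processed_query is the query projected down to attention_dim for the score computation, so once the transition agent is fed the raw decoder state, the self.ta layer must take attention_rnn_dim-wide input to match.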