Fix trans agent implementation to match the paper: use the query vector instead of processed_query

Eren Golge 2019-05-12 17:39:12 +02:00
parent 5e679f746d
commit 2b60f9a731
1 changed file with 4 additions and 4 deletions

@@ -131,7 +131,7 @@ class Attention(nn.Module):
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
+            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(attention_location_n_filters,
                                                 attention_location_kernel_size,
@@ -208,7 +208,7 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention
 
-    def apply_forward_attention(self, inputs, alignment, processed_query):
+    def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
         alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
@@ -218,7 +218,7 @@ class Attention(nn.Module):
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
-            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha
 
@@ -248,7 +248,7 @@ class Attention(nn.Module):
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, processed_query)
+            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
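
For reference, below is a minimal, self-contained sketch of forward attention with a transition agent wired the way the diff requires: the transition probability u is predicted from the context vector and the raw attention-RNN query (size attention_rnn_dim), not the processed query (size attention_dim), which is why the ta layer's input size becomes attention_rnn_dim + embedding_dim together with the argument swap. This is not the repository's code; the class name ForwardAttentionSketch, the init_states helper, and the tensor shapes are illustrative assumptions. In a decoder loop, apply_forward_attention would receive the attention RNN hidden state (attention_hidden_state in the last hunk) as query.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ForwardAttentionSketch(nn.Module):
    """Sketch of forward attention with a transition agent (not the repo's class)."""

    def __init__(self, attention_rnn_dim, embedding_dim):
        super().__init__()
        # u is predicted from [context; query]; the query is the attention RNN
        # hidden state, hence the input size attention_rnn_dim + embedding_dim.
        self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)

    def init_states(self, inputs):
        B, T, _ = inputs.size()
        # start with all attention mass on the first encoder step, u = 0.5
        self.alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T - 1)],
                               dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones(B, 1)).to(inputs.device)

    def apply_forward_attention(self, inputs, alignment, query):
        # inputs: [B, T, embedding_dim] encoder outputs
        # alignment: [B, T] softmaxed attention energies
        # query: [B, attention_rnn_dim] attention RNN hidden state
        prev_alpha = F.pad(self.alpha[:, :-1], (1, 0, 0, 0))
        # stay on the same encoder step with prob. (1 - u), advance with prob. u
        alpha = ((1 - self.u) * self.alpha + self.u * prev_alpha + 1e-8) * alignment
        self.alpha = alpha / alpha.sum(dim=1, keepdim=True)
        context = torch.bmm(self.alpha.unsqueeze(1), inputs).squeeze(1)
        # transition agent: conditions on the raw query, not the processed query
        self.u = torch.sigmoid(self.ta(torch.cat([context, query], dim=-1)))
        return context, self.alpha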