attentio update

This commit is contained in:
Eren G 2018-07-17 16:24:39 +02:00
parent d4f1ccd3ed
commit ddaf414434
2 changed files with 13 additions and 12 deletions

View File

@ -5,11 +5,11 @@ from utils.generic_utils import sequence_mask
class BahdanauAttention(nn.Module):
def __init__(self, annot_dim, query_dim, hidden_dim):
def __init__(self, annot_dim, query_dim, attn_dim):
super(BahdanauAttention, self).__init__()
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
self.v = nn.Linear(hidden_dim, 1, bias=False)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False)
def forward(self, annots, query):
"""
@ -33,8 +33,8 @@ class BahdanauAttention(nn.Module):
class LocationSensitiveAttention(nn.Module):
"""Location sensitive attention following
https://arxiv.org/pdf/1506.07503.pdf"""
def __init__(self, annot_dim, query_dim, hidden_dim,
kernel_size=7, filters=20):
def __init__(self, annot_dim, query_dim, attn_dim,
kernel_size=31, filters=32):
super(LocationSensitiveAttention, self).__init__()
self.kernel_size = kernel_size
self.filters = filters
@ -42,10 +42,10 @@ class LocationSensitiveAttention(nn.Module):
self.loc_conv = nn.Conv1d(1, filters,
kernel_size=kernel_size, stride=1,
padding=padding, bias=False)
self.loc_linear = nn.Linear(filters, hidden_dim)
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
self.v = nn.Linear(hidden_dim, 1, bias=False)
self.loc_linear = nn.Linear(filters, attn_dim)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False)
def forward(self, annot, query, loc):
"""
@ -104,6 +104,7 @@ class AttentionRNNCell(nn.Module):
- annot_lens: (batch,)
"""
# Concat input query and previous context context
print(context.shape)
rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN
# s_i = f(y_{i-1}, c_{i}, s_{i-1})

View File

@ -203,7 +203,7 @@ class Decoder(nn.Module):
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
@ -248,7 +248,7 @@ class Decoder(nn.Module):
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))]
current_context_vec = inputs.data.new(B, 256).zero_()
current_context_vec = inputs.data.new(B, 128).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()