diff --git a/layers/common_layers.py b/layers/common_layers.py
index 023c7404..592f017c 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -111,8 +111,9 @@ class LocationLayer(nn.Module):
 
 
 class GravesAttention(nn.Module):
-    """ Graves attention as described here:
+    """ Discretized Graves attention:
         - https://arxiv.org/abs/1910.10288
+        - https://arxiv.org/pdf/1906.01083.pdf
     """
     COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
 
@@ -368,82 +369,6 @@ class OriginalAttention(nn.Module):
         return context
 
 
-class GravesAttention(nn.Module):
-    """ Graves attention as described here:
-        - https://arxiv.org/abs/1910.10288
-    """
-    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
-
-    def __init__(self, query_dim, K):
-        super(GravesAttention, self).__init__()
-        self._mask_value = 0.0
-        self.K = K
-        # self.attention_alignment = 0.05
-        self.eps = 1e-5
-        self.J = None
-        self.N_a = nn.Sequential(
-            nn.Linear(query_dim, query_dim, bias=True),
-            nn.ReLU(),
-            nn.Linear(query_dim, 3*K, bias=True))
-        self.attention_weights = None
-        self.mu_prev = None
-        self.init_layers()
-
-    def init_layers(self):
-        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
-        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
-
-    def init_states(self, inputs):
-        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]).to(inputs.device)
-        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
-        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
-
-    # pylint: disable=R0201
-    # pylint: disable=unused-argument
-    def preprocess_inputs(self, inputs):
-        return None
-
-    def forward(self, query, inputs, processed_inputs, mask):
-        """
-        shapes:
-            query: B x D_attention_rnn
-            inputs: B x T_in x D_encoder
-            processed_inputs: place_holder
-            mask: B x T_in
-        """
-        gbk_t = self.N_a(query)
-        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
-
-        # attention model parameters
-        # each B x K
-        g_t = gbk_t[:, 0, :]
-        b_t = gbk_t[:, 1, :]
-        k_t = gbk_t[:, 2, :]
-
-        # attention GMM parameters
-        sig_t = torch.nn.functional.softplus(b_t) + self.eps
-
-        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
-        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
-
-        # each B x K x T_in
-        j = self.J[:inputs.size(1)]
-
-        # attention weights
-        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
-        alpha_t = self.COEF * torch.sum(phi_t, 1)
-
-        # apply masking
-        if mask is not None:
-            alpha_t.data.masked_fill_(~mask, self._mask_value)
-
-        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
-        self.attention_weights = alpha_t
-        self.mu_prev = mu_t
-        return context
-
-
 def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
               location_attention, attention_location_n_filters,
               attention_location_kernel_size, windowing, norm, forward_attn,
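
For review context, the sketch below drives the retained GravesAttention for a single decoder step with random tensors. It is a minimal illustration, not part of the patch: the batch size, sequence length, feature dimensions, number of mixture components, and the import path are assumptions about how the module is used.

# Usage sketch (illustrative only): one decoder step of discretized GMM attention.
import torch

from layers.common_layers import GravesAttention  # assumes repo root is on PYTHONPATH

B, T_in, D_encoder, D_query, K = 2, 50, 256, 1024, 5  # hypothetical sizes

attn = GravesAttention(query_dim=D_query, K=K)
inputs = torch.randn(B, T_in, D_encoder)        # encoder outputs: B x T_in x D_encoder
mask = torch.ones(B, T_in, dtype=torch.bool)    # all encoder timesteps valid
attn.init_states(inputs)                        # allocates J, attention_weights, mu_prev

query = torch.randn(B, D_query)                 # attention-RNN state for this step
context = attn(query, inputs, attn.preprocess_inputs(inputs), mask)

print(context.shape)                 # torch.Size([2, 256])
print(attn.attention_weights.shape)  # torch.Size([2, 50]), alignment over encoder steps

Each call advances the mixture means (mu_prev) monotonically, so repeating the last two lines with new queries steps the attention window forward along the input sequence.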