From 0e8881114b7cd223a41a452ea7cf570b56c109a7 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 10 Jan 2020 13:45:09 +0100
Subject: [PATCH] efficient GMM attention with native broadcasting

---
 layers/common_layers.py | 162 ++++++++++++++++++++--------------------
 1 file changed, 81 insertions(+), 81 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index 8b8ff073..c2b042b0 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -82,6 +82,11 @@ class Prenet(nn.Module):
         return x
 
 
+####################
+# ATTENTION MODULES
+####################
+
+
 class LocationLayer(nn.Module):
     def __init__(self,
                  attention_dim,
@@ -105,87 +110,6 @@ class LocationLayer(nn.Module):
         return processed_attention
 
 
-class GravesAttention(nn.Module):
-    """ Graves attention as described here:
-        - https://arxiv.org/abs/1910.10288
-    """
-    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
-
-    def __init__(self, query_dim, K):
-        super(GravesAttention, self).__init__()
-        self._mask_value = 0.0
-        self.K = K
-        # self.attention_alignment = 0.05
-        self.eps = 1e-5
-        self.J = None
-        self.N_a = nn.Sequential(
-            nn.Linear(query_dim, query_dim, bias=True),
-            nn.ReLU(),
-            nn.Linear(query_dim, 3*K, bias=True))
-        self.attention_weights = None
-        self.mu_prev = None
-        self.init_layers()
-
-    def init_layers(self):
-        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
-        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
-
-    def init_states(self, inputs):
-        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]).to(inputs.device).expand([inputs.shape[0], self.K, inputs.shape[1]])
-        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
-        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
-
-    # pylint: disable=R0201
-    # pylint: disable=unused-argument
-    def preprocess_inputs(self, inputs):
-        return None
-
-    def forward(self, query, inputs, processed_inputs, mask):
-        """
-        shapes:
-            query: B x D_attention_rnn
-            inputs: B x T_in x D_encoder
-            processed_inputs: place_holder
-            mask: B x T_in
-        """
-        gbk_t = self.N_a(query)
-        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
-
-        # attention model parameters
-        # each B x K
-        g_t = gbk_t[:, 0, :]
-        b_t = gbk_t[:, 1, :]
-        k_t = gbk_t[:, 2, :]
-
-        # attention GMM parameters
-        sig_t = torch.nn.functional.softplus(b_t) + self.eps
-
-        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
-        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
-
-        # each B x K x T_in
-        g_t = g_t.unsqueeze(2).expand(g_t.size(0),
-                                      g_t.size(1),
-                                      inputs.size(1))
-        sig_t = sig_t.unsqueeze(2).expand_as(g_t)
-        mu_t_ = mu_t.unsqueeze(2).expand_as(g_t)
-        j = self.J[:g_t.size(0), :, :inputs.size(1)]
-
-        # attention weights
-        phi_t = g_t * torch.exp(-0.5 * (mu_t_ - j)**2 / (sig_t**2))
-        alpha_t = self.COEF * torch.sum(phi_t, 1)
-
-        # apply masking
-        if mask is not None:
-            alpha_t.data.masked_fill_(~mask, self._mask_value)
-
-        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
-        self.attention_weights = alpha_t
-        self.mu_prev = mu_t
-        return context
-
-
 class OriginalAttention(nn.Module):
     """Following the methods proposed here:
         - https://arxiv.org/abs/1712.05884
@@ -365,6 +289,82 @@ class OriginalAttention(nn.Module):
         return context
 
 
+class GravesAttention(nn.Module):
+    """ Graves attention as described here:
+        - https://arxiv.org/abs/1910.10288
+    """
+    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
+
+    def __init__(self, query_dim, K):
+        super(GravesAttention, self).__init__()
+        self._mask_value = 0.0
+        self.K = K
+        # self.attention_alignment = 0.05
+        self.eps = 1e-5
+        self.J = None
+        self.N_a = nn.Sequential(
+            nn.Linear(query_dim, query_dim, bias=True),
+            nn.ReLU(),
+            nn.Linear(query_dim, 3*K, bias=True))
+        self.attention_weights = None
+        self.mu_prev = None
+        self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
+        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
+
+    def init_states(self, inputs):
+        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
+            self.J = torch.arange(0, inputs.shape[1]).to(inputs.device)
+        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
+        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
+
+    # pylint: disable=R0201
+    # pylint: disable=unused-argument
+    def preprocess_inputs(self, inputs):
+        return None
+
+    def forward(self, query, inputs, processed_inputs, mask):
+        """
+        shapes:
+            query: B x D_attention_rnn
+            inputs: B x T_in x D_encoder
+            processed_inputs: place_holder
+            mask: B x T_in
+        """
+        gbk_t = self.N_a(query)
+        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
+
+        # attention model parameters
+        # each B x K
+        g_t = gbk_t[:, 0, :]
+        b_t = gbk_t[:, 1, :]
+        k_t = gbk_t[:, 2, :]
+
+        # attention GMM parameters
+        sig_t = torch.nn.functional.softplus(b_t) + self.eps
+
+        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
+        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
+
+        # each B x K x T_in
+        j = self.J[:inputs.size(1)]
+
+        # attention weights
+        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
+        alpha_t = self.COEF * torch.sum(phi_t, 1)
+
+        # apply masking
+        if mask is not None:
+            alpha_t.data.masked_fill_(~mask, self._mask_value)
+
+        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
+        self.attention_weights = alpha_t
+        self.mu_prev = mu_t
+        return context
+
+
 def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
               location_attention, attention_location_n_filters,
               attention_location_kernel_size, windowing, norm, forward_attn,
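
Note (illustrative sketch, not part of the patch): the snippet below mirrors the broadcast-based weight computation that the new GravesAttention.forward() uses and checks it against the old expand()-based form removed above. The tensor sizes, random inputs, and variable names here are made up for the demonstration only.

import torch

# toy sizes: B = batch, K = mixture components, T_in = encoder timesteps
B, K, T_in = 2, 5, 7
COEF = 0.3989422917366028  # 1 / sqrt(2 * pi)

g_t = torch.softmax(torch.randn(B, K), dim=-1)                   # mixture weights, B x K
sig_t = torch.nn.functional.softplus(torch.randn(B, K)) + 1e-5   # std deviations, B x K
mu_t = torch.rand(B, K) * T_in                                   # component means, B x K
j = torch.arange(0, T_in).float()                                # encoder positions, T_in

# new form: unsqueeze to B x K x 1 and let broadcasting yield B x K x T_in
phi_new = g_t.unsqueeze(-1) * torch.exp(
    -0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / (sig_t.unsqueeze(-1) ** 2))

# old form: materialize B x K x T_in tensors with expand()/expand_as()
g_e = g_t.unsqueeze(2).expand(B, K, T_in)
sig_e = sig_t.unsqueeze(2).expand_as(g_e)
mu_e = mu_t.unsqueeze(2).expand_as(g_e)
j_e = j.expand(B, K, T_in)
phi_old = g_e * torch.exp(-0.5 * (mu_e - j_e) ** 2 / (sig_e ** 2))

# both reduce over the K mixture components to B x T_in attention weights
alpha_new = COEF * torch.sum(phi_new, 1)
alpha_old = COEF * torch.sum(phi_old, 1)
assert torch.allclose(alpha_new, alpha_old)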