diff --git a/config.json b/config.json
index ef999fa9..0bf6c378 100644
--- a/config.json
+++ b/config.json
@@ -108,7 +108,7 @@
     [
         {
             "name": "ljspeech",
-            "path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/",
+            "path": "/root/LJSpeech-1.1/",
             // "path": "/home/erogol/Data/LJSpeech-1.1",
             "meta_file_train": "metadata_train.csv",
             "meta_file_val": "metadata_val.csv"
diff --git a/layers/common_layers.py b/layers/common_layers.py
index c2b042b0..5365d605 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -110,6 +110,84 @@ class LocationLayer(nn.Module):
         return processed_attention
 
 
+class GravesAttention(nn.Module):
+    """ Graves attention as described here:
+        - https://arxiv.org/abs/1910.10288
+    """
+    COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi))
+
+    def __init__(self, query_dim, K):
+        super(GravesAttention, self).__init__()
+        self._mask_value = 0.0
+        self.K = K
+        # self.attention_alignment = 0.05
+        self.eps = 1e-5
+        self.J = None
+        self.N_a = nn.Sequential(
+            nn.Linear(query_dim, query_dim, bias=True),
+            nn.ReLU(),
+            nn.Linear(query_dim, 3*K, bias=True))
+        self.attention_weights = None
+        self.mu_prev = None
+        self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
+        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
+
+    def init_states(self, inputs):
+        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
+            self.J = torch.arange(0, inputs.shape[1]+1).to(inputs.device) + 0.5
+        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
+        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
+
+    # pylint: disable=R0201
+    # pylint: disable=unused-argument
+    def preprocess_inputs(self, inputs):
+        return None
+
+    def forward(self, query, inputs, processed_inputs, mask):
+        """
+        shapes:
+            query: B x D_attention_rnn
+            inputs: B x T_in x D_encoder
+            processed_inputs: place_holder
+            mask: B x T_in
+        """
+        gbk_t = self.N_a(query)
+        gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
+
+        # attention model parameters
+        # each B x K
+        g_t = gbk_t[:, 0, :]
+        b_t = gbk_t[:, 1, :]
+        k_t = gbk_t[:, 2, :]
+
+        # attention GMM parameters
+        sig_t = torch.nn.functional.softplus(b_t) + self.eps
+
+        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
+        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
+
+        j = self.J[:inputs.size(1)+1]
+
+        # attention weights
+        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
+
+        # discretize attention weights
+        alpha_t = self.COEF * torch.sum(phi_t, 1)
+        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
+
+        # apply masking
+        if mask is not None:
+            alpha_t.data.masked_fill_(~mask, self._mask_value)
+
+        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
+        self.attention_weights = alpha_t
+        self.mu_prev = mu_t
+        return context
+
+
 class OriginalAttention(nn.Module):
     """Following the methods proposed here:
         - https://arxiv.org/abs/1712.05884
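For reviewers who want to exercise the new attention module in isolation, the snippet below is a minimal smoke-test sketch, not part of the patch. It assumes the repo root is on `PYTHONPATH` so `layers.common_layers` is importable; the batch size, input length, and dimensions are arbitrary illustrative values.

```python
# Minimal smoke test for GravesAttention (sketch; shapes are illustrative only).
import torch
from layers.common_layers import GravesAttention

B, T_in, D_enc, D_query, K = 4, 50, 512, 256, 5    # assumed example sizes

attn = GravesAttention(query_dim=D_query, K=K)

inputs = torch.rand(B, T_in, D_enc)                # encoder outputs: B x T_in x D_encoder
mask = torch.ones(B, T_in, dtype=torch.bool)       # all timesteps valid in this sketch

attn.init_states(inputs)                           # builds self.J, resets mu_prev / weights
query = torch.rand(B, D_query)                     # one attention-RNN step: B x D_attention_rnn

context = attn(query, inputs, attn.preprocess_inputs(inputs), mask)
print(context.shape)                 # torch.Size([4, 512])  -> B x D_encoder
print(attn.attention_weights.shape)  # torch.Size([4, 50])   -> B x T_in
```

Calling the module with successive decoder queries moves the mixture forward monotonically, since `mu_t = mu_prev + softplus(k_t)` can only grow; that monotonicity is the property that makes GMM attention attractive for alignment in TTS.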