From 5e148038be5971f2c7c811d46a1d7b28c759ecda Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 9 Jan 2020 15:56:09 +0100
Subject: [PATCH] simpler gmm attention implementation

---
 config.json             |  2 +-
 layers/common_layers.py | 15 ++++++---------
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/config.json b/config.json
index 91863c4c..d23246a7 100644
--- a/config.json
+++ b/config.json
@@ -109,7 +109,7 @@
         [
             {
                 "name": "ljspeech",
-                "path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/",
+                "path": "/root/LJSpeech-1.1/",
                 // "path": "/home/erogol/Data/LJSpeech-1.1",
                 "meta_file_train": "metadata_train.csv",
                 "meta_file_val": "metadata_val.csv"
diff --git a/layers/common_layers.py b/layers/common_layers.py
index 8b8ff073..112760b3 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -132,7 +132,7 @@ class GravesAttention(nn.Module):
 
     def init_states(self, inputs):
         if self.J is None or inputs.shape[1] > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]).to(inputs.device).expand([inputs.shape[0], self.K, inputs.shape[1]])
+            self.J = torch.arange(0, inputs.shape[1]+1).to(inputs.device) + 0.5
         self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
         self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
 
@@ -164,17 +164,14 @@ class GravesAttention(nn.Module):
         mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
         g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
 
-        # each B x K x T_in
-        g_t = g_t.unsqueeze(2).expand(g_t.size(0),
-                                      g_t.size(1),
-                                      inputs.size(1))
-        sig_t = sig_t.unsqueeze(2).expand_as(g_t)
-        mu_t_ = mu_t.unsqueeze(2).expand_as(g_t)
-        j = self.J[:g_t.size(0), :, :inputs.size(1)]
+        j = self.J[:inputs.size(1)+1]
 
         # attention weights
-        phi_t = g_t * torch.exp(-0.5 * (mu_t_ - j)**2 / (sig_t**2))
+        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
+
+        # discretize attention weights
         alpha_t = self.COEF * torch.sum(phi_t, 1)
+        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
 
         # apply masking
         if mask is not None:
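
For reference, below is a minimal standalone sketch (not part of the patch) of the simplified Graves/GMM attention step that the patched forward pass computes. The tensor names g_t, sig_t, mu_t and the half-offset grid j follow the diff; the sizes B, K, T_in, the eps constant, the value of COEF (taken here as 1/sqrt(2*pi)), and the way sig_t is produced are illustrative assumptions, since those parts lie outside the hunks shown above.

    # sketch of the patched GravesAttention step, under the assumptions stated above
    import math
    import torch

    B, K, T_in = 2, 5, 7                  # assumed batch size, mixture count, input length
    eps = 1e-5                            # assumed numerical-stability constant
    COEF = 1.0 / math.sqrt(2 * math.pi)   # assumed value of self.COEF

    # raw mixture parameters, as they would come out of the attention network
    g_t = torch.randn(B, K)                                           # mixture weights (pre-softmax)
    sig_t = torch.nn.functional.softplus(torch.randn(B, K)) + eps     # assumed std-dev computation
    mu_prev = torch.zeros(B, K)                                       # means from the previous step
    k_t = torch.randn(B, K)
    mu_t = mu_prev + torch.nn.functional.softplus(k_t)                # monotonically advancing means

    g_t = torch.softmax(g_t, dim=-1) / sig_t + eps

    # 1-D grid of T_in + 1 points shifted by 0.5; broadcasting against the
    # unsqueezed (B, K, 1) parameters replaces the explicit expand() buffers
    # that the patch removes
    j = torch.arange(0, T_in + 1).float() + 0.5

    phi_t = g_t.unsqueeze(-1) * torch.exp(
        -0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / (sig_t.unsqueeze(-1) ** 2))

    # sum over the K mixture components, then difference adjacent grid points
    # (the patch's "discretize attention weights" step) to get one weight per
    # input position
    alpha_t = COEF * torch.sum(phi_t, 1)
    alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]

    print(alpha_t.shape)   # torch.Size([2, 7]), i.e. (B, T_in)

The design point of the change is that a single shared 1-D grid of T_in + 1 half-offset positions, combined with broadcasting, stands in for the per-batch, per-component expanded buffers of the old code, and differencing the summed values at adjacent grid points converts the continuous mixture into per-position attention weights.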