From 98edb7a4f8d7e99ab7dcdca036d27762f27e4dd9 Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Tue, 23 Jul 2019 18:38:09 +0200 Subject: [PATCH 1/7] renamed attention_rnn to query_rnn --- layers/common_layers.py | 24 +++++++------- layers/tacotron.py | 57 ++++++++++++++++---------------- layers/tacotron2.py | 72 ++++++++++++++++++++--------------------- 3 files changed, 79 insertions(+), 74 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index 2edf0dab..77ce4f4a 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -108,19 +108,19 @@ class LocationLayer(nn.Module): class Attention(nn.Module): # Pylint gets confused by PyTorch conventions here #pylint: disable=attribute-defined-outside-init - def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + def __init__(self, query_dim, embedding_dim, attention_dim, location_attention, attention_location_n_filters, attention_location_kernel_size, windowing, norm, forward_attn, trans_agent, forward_attn_mask): super(Attention, self).__init__() self.query_layer = Linear( - attention_rnn_dim, attention_dim, bias=False, init_gain='tanh') + query_dim, attention_dim, bias=False, init_gain='tanh') self.inputs_layer = Linear( embedding_dim, attention_dim, bias=False, init_gain='tanh') self.v = Linear(attention_dim, 1, bias=True) if trans_agent: self.ta = nn.Linear( - attention_rnn_dim + embedding_dim, 1, bias=True) + query_dim + embedding_dim, 1, bias=True) if location_attention: self.location_layer = LocationLayer( attention_dim, @@ -203,11 +203,12 @@ class Attention(nn.Module): def apply_forward_attention(self, inputs, alignment, query): # forward attention - prev_alpha = F.pad(self.alpha[:, :-1].clone(), - (1, 0, 0, 0)).to(inputs.device) + prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device), + (1, 0, 0, 0)) # compute transition potentials - alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) + - self.u * prev_alpha) + 1e-8) * alignment + alpha = ((1 - self.u) * self.alpha + + self.u * prev_alpha + + 1e-8) * alignment # force incremental alignment if not self.training and self.forward_attn_mask: _, n = prev_alpha.max(1) @@ -231,19 +232,20 @@ class Attention(nn.Module): self.u = torch.sigmoid(self.ta(ta_input)) return context, self.alpha - def forward(self, attention_hidden_state, inputs, processed_inputs, mask): + def forward(self, query, inputs, processed_inputs, mask): if self.location_attention: attention, processed_query = self.get_location_attention( - attention_hidden_state, processed_inputs) + query, processed_inputs) else: attention, processed_query = self.get_attention( - attention_hidden_state, processed_inputs) + query, processed_inputs) # apply masking if mask is not None: attention.data.masked_fill_(1 - mask, self._mask_value) # apply windowing - only in eval mode if not self.training and self.windowing: attention = self.apply_windowing(attention, inputs) + # normalize attention values if self.norm == "softmax": alignment = torch.softmax(attention, dim=-1) @@ -258,7 +260,7 @@ class Attention(nn.Module): # apply forward attention if enabled if self.forward_attn: context, self.attention_weights = self.apply_forward_attention( - inputs, alignment, attention_hidden_state) + inputs, alignment, query) else: context = torch.bmm(alignment.unsqueeze(1), inputs) context = context.squeeze(1) diff --git a/layers/tacotron.py b/layers/tacotron.py index b71ddbc3..068ae7cc 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -283,6 +283,7 @@ class Decoder(nn.Module): 
self.memory_size = memory_size if memory_size > 0 else r self.memory_dim = memory_dim self.separate_stopnet = separate_stopnet + self.query_dim = 256 # memory -> |Prenet| -> processed_memory self.prenet = Prenet( memory_dim * self.memory_size, @@ -290,18 +291,18 @@ class Decoder(nn.Module): prenet_dropout, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State - self.attention_rnn = nn.GRUCell(in_features + 128, 256) - self.attention_layer = Attention(attention_rnn_dim=256, - embedding_dim=in_features, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_windowing, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask) + self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim) + self.attention = Attention(query_dim=self.query_dim, + embedding_dim=in_features, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256 + in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state @@ -310,7 +311,7 @@ class Decoder(nn.Module): # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) # learn init values instead of zero init. - self.attention_rnn_init = nn.Embedding(1, 256) + self.query_rnn_init = nn.Embedding(1, 256) self.memory_init = nn.Embedding(1, self.memory_size * memory_dim) self.decoder_rnn_inits = nn.Embedding(2, 256) self.stopnet = StopNet(256 + memory_dim * r) @@ -347,18 +348,18 @@ class Decoder(nn.Module): self.memory_input = self.memory_init(inputs.data.new_zeros(B).long()) # decoder states - self.attention_rnn_hidden = self.attention_rnn_init( + self.query = self.query_rnn_init( inputs.data.new_zeros(B).long()) self.decoder_rnn_hiddens = [ self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long()) for idx in range(len(self.decoder_rnns)) ] - self.current_context_vec = inputs.data.new(B, self.in_features).zero_() + self.context_vec = inputs.data.new(B, self.in_features).zero_() # attention states self.attention = inputs.data.new(B, T).zero_() self.attention_cum = inputs.data.new(B, T).zero_() # cache attention inputs - self.processed_inputs = self.attention_layer.inputs_layer(inputs) + self.processed_inputs = self.attention.inputs_layer(inputs) def _parse_outputs(self, outputs, attentions, stop_tokens): # Back to batch first @@ -370,13 +371,15 @@ class Decoder(nn.Module): def decode(self, inputs, mask=None): # Prenet processed_memory = self.prenet(self.memory_input) + # Attention RNN - self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden) - self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) - # Concat RNN output and attention context vector + self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query) + self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask) + + # Concat query and attention context vector decoder_input = self.project_to_decoder_in( - torch.cat((self.attention_rnn_hidden, self.current_context_vec), - 
-1)) + torch.cat((self.query, self.context_vec), -1)) + # Pass through the decoder RNNs for idx in range(len(self.decoder_rnns)): self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx]( @@ -384,18 +387,18 @@ class Decoder(nn.Module): # Residual connection decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input decoder_output = decoder_input - del decoder_input + # predict mel vectors from decoder vectors output = self.proj_to_mel(decoder_output) output = torch.sigmoid(output) + # predict stop token stopnet_input = torch.cat([decoder_output, output], -1) - del decoder_output if self.separate_stopnet: stop_token = self.stopnet(stopnet_input.detach()) else: stop_token = self.stopnet(stopnet_input) - return output, stop_token, self.attention_layer.attention_weights + return output, stop_token, self.attention.attention_weights def _update_memory_queue(self, new_memory): if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size: @@ -427,7 +430,7 @@ class Decoder(nn.Module): stop_tokens = [] t = 0 self._init_states(inputs) - self.attention_layer.init_states(inputs) + self.attention.init_states(inputs) while len(outputs) < memory.size(0): if t > 0: new_memory = memory[t - 1] @@ -453,8 +456,8 @@ class Decoder(nn.Module): stop_tokens = [] t = 0 self._init_states(inputs) - self.attention_layer.init_win_idx() - self.attention_layer.init_states(inputs) + self.attention.init_win_idx() + self.attention.init_states(inputs) while True: if t > 0: new_memory = outputs[-1] diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 802f158e..ba52abe2 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -104,7 +104,7 @@ class Decoder(nn.Module): self.r = r self.encoder_embedding_dim = in_features self.separate_stopnet = separate_stopnet - self.attention_rnn_dim = 1024 + self.query_dim = 1024 self.decoder_rnn_dim = 1024 self.prenet_dim = 256 self.max_decoder_steps = 1000 @@ -116,22 +116,22 @@ class Decoder(nn.Module): prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False) - self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features, - self.attention_rnn_dim) + self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features, + self.query_dim) - self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim, - embedding_dim=in_features, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_win, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask) + self.attention = Attention(query_dim=self.query_dim, + embedding_dim=in_features, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask) - self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features, + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features, self.decoder_rnn_dim, 1) self.linear_projection = Linear(self.decoder_rnn_dim + in_features, @@ -145,7 +145,7 @@ class Decoder(nn.Module): bias=True, init_gain='sigmoid')) - self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim) + self.query_rnn_init = nn.Embedding(1, self.query_dim) self.go_frame_init = nn.Embedding(1, self.mel_channels * r) self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim) self.memory_truncated = None @@ -160,10 +160,10 @@ class Decoder(nn.Module): # T = 
inputs.size(1) if not keep_states: - self.attention_hidden = self.attention_rnn_init( + self.query = self.query_rnn_init( inputs.data.new_zeros(B).long()) - self.attention_cell = Variable( - inputs.data.new(B, self.attention_rnn_dim).zero_()) + self.query_rnn_cell_state = Variable( + inputs.data.new(B, self.query_dim).zero_()) self.decoder_hidden = self.decoder_rnn_inits( inputs.data.new_zeros(B).long()) @@ -174,7 +174,7 @@ class Decoder(nn.Module): inputs.data.new(B, self.encoder_embedding_dim).zero_()) self.inputs = inputs - self.processed_inputs = self.attention_layer.inputs_layer(inputs) + self.processed_inputs = self.attention.inputs_layer(inputs) self.mask = mask def _reshape_memory(self, memories): @@ -193,18 +193,18 @@ class Decoder(nn.Module): return outputs, stop_tokens, alignments def decode(self, memory): - cell_input = torch.cat((memory, self.context), -1) - self.attention_hidden, self.attention_cell = self.attention_rnn( - cell_input, (self.attention_hidden, self.attention_cell)) - self.attention_hidden = F.dropout( - self.attention_hidden, self.p_attention_dropout, self.training) - self.attention_cell = F.dropout( - self.attention_cell, self.p_attention_dropout, self.training) + query_input = torch.cat((memory, self.context), -1) + self.query, self.query_rnn_cell_state = self.query_rnn( + query_input, (self.query, self.query_rnn_cell_state)) + self.query = F.dropout( + self.query, self.p_attention_dropout, self.training) + self.query_rnn_cell_state = F.dropout( + self.query_rnn_cell_state, self.p_attention_dropout, self.training) - self.context = self.attention_layer(self.attention_hidden, self.inputs, - self.processed_inputs, self.mask) + self.context = self.attention(self.query, self.inputs, + self.processed_inputs, self.mask) - memory = torch.cat((self.attention_hidden, self.context), -1) + memory = torch.cat((self.query, self.context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( memory, (self.decoder_hidden, self.decoder_cell)) self.decoder_hidden = F.dropout(self.decoder_hidden, @@ -223,7 +223,7 @@ class Decoder(nn.Module): stop_token = self.stopnet(stopnet_input.detach()) else: stop_token = self.stopnet(stopnet_input) - return decoder_output, stop_token, self.attention_layer.attention_weights + return decoder_output, stop_token, self.attention.attention_weights def forward(self, inputs, memories, mask): memory = self.get_go_frame(inputs).unsqueeze(0) @@ -232,7 +232,7 @@ class Decoder(nn.Module): memories = self.prenet(memories) self._init_states(inputs, mask=mask) - self.attention_layer.init_states(inputs) + self.attention.init_states(inputs) outputs, stop_tokens, alignments = [], [], [] while len(outputs) < memories.size(0) - 1: @@ -251,8 +251,8 @@ class Decoder(nn.Module): memory = self.get_go_frame(inputs) self._init_states(inputs, mask=None) - self.attention_layer.init_win_idx() - self.attention_layer.init_states(inputs) + self.attention.init_win_idx() + self.attention.init_states(inputs) outputs, stop_tokens, alignments, t = [], [], [], 0 stop_flags = [True, False, False] @@ -295,8 +295,8 @@ class Decoder(nn.Module): else: self._init_states(inputs, mask=None, keep_states=True) - self.attention_layer.init_win_idx() - self.attention_layer.init_states(inputs) + self.attention.init_win_idx() + self.attention.init_states(inputs) outputs, stop_tokens, alignments, t = [], [], [], 0 stop_flags = [True, False, False] stop_count = 0 From 82db35530f327afbaea99f249a23369faba30513 Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Tue, 23 Jul 2019 
19:33:56 +0200 Subject: [PATCH 2/7] unused var --- layers/common_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index 77ce4f4a..bc353be3 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -234,10 +234,10 @@ class Attention(nn.Module): def forward(self, query, inputs, processed_inputs, mask): if self.location_attention: - attention, processed_query = self.get_location_attention( + attention, _ = self.get_location_attention( query, processed_inputs) else: - attention, processed_query = self.get_attention( + attention, _ = self.get_attention( query, processed_inputs) # apply masking if mask is not None: From fb7c5b1996532ccc5b214b9b6e462c574b037dbd Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Tue, 23 Jul 2019 20:02:31 +0200 Subject: [PATCH 3/7] unused instance vars --- layers/tacotron.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index 068ae7cc..31d6cd84 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -355,9 +355,6 @@ class Decoder(nn.Module): for idx in range(len(self.decoder_rnns)) ] self.context_vec = inputs.data.new(B, self.in_features).zero_() - # attention states - self.attention = inputs.data.new(B, T).zero_() - self.attention_cum = inputs.data.new(B, T).zero_() # cache attention inputs self.processed_inputs = self.attention.inputs_layer(inputs) From a6118564d578f314dfa787a80cc288dba2228dfa Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Wed, 24 Jul 2019 11:46:34 +0200 Subject: [PATCH 4/7] renamed query_rnn back to attention_rnn --- layers/tacotron.py | 12 +++++++----- layers/tacotron2.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index 31d6cd84..40225fa5 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -291,7 +291,9 @@ class Decoder(nn.Module): prenet_dropout, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State - self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim) + # attention_rnn generates queries for the attention mechanism + self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim) + self.attention = Attention(query_dim=self.query_dim, embedding_dim=in_features, attention_dim=128, @@ -311,7 +313,7 @@ class Decoder(nn.Module): # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) # learn init values instead of zero init. 
- self.query_rnn_init = nn.Embedding(1, 256) + self.attention_rnn_init = nn.Embedding(1, 256) self.memory_init = nn.Embedding(1, self.memory_size * memory_dim) self.decoder_rnn_inits = nn.Embedding(2, 256) self.stopnet = StopNet(256 + memory_dim * r) @@ -348,7 +350,7 @@ class Decoder(nn.Module): self.memory_input = self.memory_init(inputs.data.new_zeros(B).long()) # decoder states - self.query = self.query_rnn_init( + self.query = self.attention_rnn_init( inputs.data.new_zeros(B).long()) self.decoder_rnn_hiddens = [ self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long()) @@ -369,8 +371,8 @@ class Decoder(nn.Module): # Prenet processed_memory = self.prenet(self.memory_input) - # Attention RNN - self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query) + # Attention + self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query) self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask) # Concat query and attention context vector diff --git a/layers/tacotron2.py b/layers/tacotron2.py index ba52abe2..358d1807 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -116,8 +116,8 @@ class Decoder(nn.Module): prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False) - self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features, - self.query_dim) + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features, + self.query_dim) self.attention = Attention(query_dim=self.query_dim, embedding_dim=in_features, @@ -145,7 +145,7 @@ class Decoder(nn.Module): bias=True, init_gain='sigmoid')) - self.query_rnn_init = nn.Embedding(1, self.query_dim) + self.attention_rnn_init = nn.Embedding(1, self.query_dim) self.go_frame_init = nn.Embedding(1, self.mel_channels * r) self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim) self.memory_truncated = None @@ -160,9 +160,9 @@ class Decoder(nn.Module): # T = inputs.size(1) if not keep_states: - self.query = self.query_rnn_init( + self.query = self.attention_rnn_init( inputs.data.new_zeros(B).long()) - self.query_rnn_cell_state = Variable( + self.attention_rnn_cell_state = Variable( inputs.data.new(B, self.query_dim).zero_()) self.decoder_hidden = self.decoder_rnn_inits( @@ -194,12 +194,12 @@ class Decoder(nn.Module): def decode(self, memory): query_input = torch.cat((memory, self.context), -1) - self.query, self.query_rnn_cell_state = self.query_rnn( - query_input, (self.query, self.query_rnn_cell_state)) + self.query, self.attention_rnn_cell_state = self.attention_rnn( + query_input, (self.query, self.attention_rnn_cell_state)) self.query = F.dropout( self.query, self.p_attention_dropout, self.training) - self.query_rnn_cell_state = F.dropout( - self.query_rnn_cell_state, self.p_attention_dropout, self.training) + self.attention_rnn_cell_state = F.dropout( + self.attention_rnn_cell_state, self.p_attention_dropout, self.training) self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) From 40f56f9b000bb03384ebe883c03380b260a6a205 Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Wed, 24 Jul 2019 11:47:06 +0200 Subject: [PATCH 5/7] simplified code for fwd attn --- layers/common_layers.py | 42 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index bc353be3..bfdd6775 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -201,17 +201,17 @@ class Attention(nn.Module): self.win_idx = 
torch.argmax(attention, 1).long()[0].item() return attention - def apply_forward_attention(self, inputs, alignment, query): + def apply_forward_attention(self, alignment): # forward attention - prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device), - (1, 0, 0, 0)) + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), + (1, 0, 0, 0)) # compute transition potentials alpha = ((1 - self.u) * self.alpha - + self.u * prev_alpha + + self.u * fwd_shifted_alpha + 1e-8) * alignment # force incremental alignment if not self.training and self.forward_attn_mask: - _, n = prev_alpha.max(1) + _, n = fwd_shifted_alpha.max(1) val, n2 = alpha.max(1) for b in range(alignment.shape[0]): alpha[b, n[b] + 3:] = 0 @@ -221,16 +221,9 @@ class Attention(nn.Module): alpha[b, (n[b] - 2 )] = 0.01 * val[b] # smoothing factor for the prev step - # compute attention weights - self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1) - # compute context - context = torch.bmm(self.alpha.unsqueeze(1), inputs) - context = context.squeeze(1) - # compute transition agent - if self.trans_agent: - ta_input = torch.cat([context, query.squeeze(1)], dim=-1) - self.u = torch.sigmoid(self.ta(ta_input)) - return context, self.alpha + # renormalize attention weights + alpha = alpha / alpha.sum(dim=1, keepdim=True) + return alpha def forward(self, query, inputs, processed_inputs, mask): if self.location_attention: @@ -254,15 +247,20 @@ class Attention(nn.Module): attention).sum( dim=1, keepdim=True) else: - raise RuntimeError("Unknown value for attention norm type") + raise ValueError("Unknown value for attention norm type") if self.location_attention: self.update_location_attention(alignment) # apply forward attention if enabled if self.forward_attn: - context, self.attention_weights = self.apply_forward_attention( - inputs, alignment, query) - else: - context = torch.bmm(alignment.unsqueeze(1), inputs) - context = context.squeeze(1) - self.attention_weights = alignment + alignment = self.apply_forward_attention(alignment) + self.alpha = alignment + + context = torch.bmm(alignment.unsqueeze(1), inputs) + context = context.squeeze(1) + self.attention_weights = alignment + + # compute transition agent + if self.forward_attn and self.trans_agent: + ta_input = torch.cat([context, query.squeeze(1)], dim=-1) + self.u = torch.sigmoid(self.ta(ta_input)) return context From f3dac0aa840a893f7222b6444d5bc7f5f40d623d Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Wed, 24 Jul 2019 11:49:07 +0200 Subject: [PATCH 6/7] updating location attn after calculating fwd attention --- layers/common_layers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index bfdd6775..a652b8a6 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -248,13 +248,14 @@ class Attention(nn.Module): dim=1, keepdim=True) else: raise ValueError("Unknown value for attention norm type") - if self.location_attention: - self.update_location_attention(alignment) # apply forward attention if enabled if self.forward_attn: alignment = self.apply_forward_attention(alignment) self.alpha = alignment + if self.location_attention: + self.update_location_attention(alignment) + context = torch.bmm(alignment.unsqueeze(1), inputs) context = context.squeeze(1) self.attention_weights = alignment From ab42396fbfbd647f8f7f67f660250d9f75219643 Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Thu, 25 Jul 2019 13:04:41 +0200 Subject: [PATCH 7/7] undo loc attn after fwd attn --- 
layers/common_layers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index a652b8a6..0a7216ef 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -248,14 +248,15 @@ class Attention(nn.Module): dim=1, keepdim=True) else: raise ValueError("Unknown value for attention norm type") + + if self.location_attention: + self.update_location_attention(alignment) + # apply forward attention if enabled if self.forward_attn: alignment = self.apply_forward_attention(alignment) self.alpha = alignment - if self.location_attention: - self.update_location_attention(alignment) - context = torch.bmm(alignment.unsqueeze(1), inputs) context = context.squeeze(1) self.attention_weights = alignment
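
Editor's note on the forward-attention recursion introduced in PATCH 5/7: apply_forward_attention() mixes the previous attention weights self.alpha with a copy shifted one encoder step forward, gates the mix by the transition-agent probability self.u, re-weights by the current alignment, and renormalizes. The following is a minimal standalone sketch of that update on dummy tensors, not part of the patches themselves; the batch size B, encoder length T, and the fixed 0.5 transition probability are illustrative assumptions only.

    # Sketch of the apply_forward_attention() update from PATCH 5/7 (illustrative only).
    import torch
    import torch.nn.functional as F

    B, T = 2, 6                                    # assumed batch size and encoder length
    alpha = F.softmax(torch.rand(B, T), dim=1)     # previous step's forward-attention weights
    u = torch.full((B, 1), 0.5)                    # assumed transition-agent probability
    alignment = F.softmax(torch.rand(B, T), dim=1) # current normalized alignment

    # shift alpha one encoder step forward, as in the patch
    fwd_shifted_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))
    # mix "stay" and "move" probabilities, then re-weight by the current alignment
    new_alpha = ((1 - u) * alpha + u * fwd_shifted_alpha + 1e-8) * alignment
    # renormalize so each row sums to one, matching alpha.sum(dim=1, keepdim=True) in the patch
    new_alpha = new_alpha / new_alpha.sum(dim=1, keepdim=True)

    print(new_alpha.sum(dim=1))  # ~1.0 per batch element

After PATCH 7/7, the call order in Attention.forward() is: normalize the raw attention, update the location-attention state with that alignment, then (if enabled) apply the forward-attention update above and use the result to compute the context vector and, optionally, the transition agent.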