Merge pull request #245 from twerkmeister/attention_naming

renamed attention_rnn to query_rnn
Eren Gölge 2019-08-13 10:54:08 +02:00 committed by GitHub
commit 69abea55d6
3 changed files with 100 additions and 96 deletions
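In short: the first constructor argument of Attention (formerly attention_rnn_dim) and the first argument of Attention.forward() (formerly attention_hidden_state) are now named after the query that the attention RNN produces, and the decoders refer to the module as self.attention. The snippet below is a minimal, self-contained illustration of those roles; ToyAttention and its dimensions are made up for this example and are not the project's actual class.

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Illustrative only: mirrors the query_layer / inputs_layer / v naming."""

    def __init__(self, query_dim, embedding_dim, attention_dim):
        super(ToyAttention, self).__init__()
        self.query_layer = nn.Linear(query_dim, attention_dim, bias=False)
        self.inputs_layer = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=True)

    def forward(self, query, inputs):
        # query: [B, query_dim] -- the attention RNN hidden state
        # inputs: [B, T, embedding_dim] -- the encoder outputs
        energies = self.v(torch.tanh(
            self.query_layer(query.unsqueeze(1)) + self.inputs_layer(inputs)))
        alignment = torch.softmax(energies.squeeze(-1), dim=-1)         # [B, T]
        context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)  # [B, embedding_dim]
        return context, alignment


attn = ToyAttention(query_dim=256, embedding_dim=512, attention_dim=128)
context, alignment = attn(torch.randn(2, 256), torch.randn(2, 40, 512))
print(context.shape, alignment.shape)  # torch.Size([2, 512]) torch.Size([2, 40])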


@@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, query_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
                  trans_agent, forward_attn_mask):
         super(Attention, self).__init__()
         self.query_layer = Linear(
-            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
+            query_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
             self.ta = nn.Linear(
-                attention_rnn_dim + embedding_dim, 1, bias=True)
+                query_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
                 attention_dim,
@@ -201,16 +201,17 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention

-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
-                           (1, 0, 0, 0)).to(inputs.device)
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
+                                  (1, 0, 0, 0))
         # compute transition potentials
-        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
-                  self.u * prev_alpha) + 1e-8) * alignment
+        alpha = ((1 - self.u) * self.alpha
+                 + self.u * fwd_shifted_alpha
+                 + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
@@ -220,30 +221,24 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha

-    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
-            attention, processed_query = self.get_location_attention(
-                attention_hidden_state, processed_inputs)
+            attention, _ = self.get_location_attention(
+                query, processed_inputs)
         else:
-            attention, processed_query = self.get_attention(
-                attention_hidden_state, processed_inputs)
+            attention, _ = self.get_attention(
+                query, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
         # apply windowing - only in eval mode
         if not self.training and self.windowing:
             attention = self.apply_windowing(attention, inputs)
         # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
@@ -252,15 +247,22 @@ class Attention(nn.Module):
                 attention).sum(
                     dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")
         if self.location_attention:
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, attention_hidden_state)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
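With this change apply_forward_attention() only updates and renormalizes the alignment; the context vector and the transition-agent update now happen afterwards in forward(). Below is a rough, standalone numeric sketch of that alignment update, with made-up values for a single batch item over five encoder steps (illustrative only, not the project's code).

import torch
import torch.nn.functional as F

alpha = torch.tensor([[0.7, 0.2, 0.1, 0.0, 0.0]])      # previous forward-attention weights
alignment = torch.tensor([[0.1, 0.6, 0.2, 0.05, 0.05]])  # current normalized attention
u = torch.tensor([[0.5]])                                # transition-agent probability

# shift alpha one step to the right (the fwd_shifted_alpha term)
fwd_shifted_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))
new_alpha = ((1 - u) * alpha + u * fwd_shifted_alpha + 1e-8) * alignment
new_alpha = new_alpha / new_alpha.sum(dim=1, keepdim=True)  # renormalize
print(new_alpha)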


@@ -283,6 +283,7 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         self.separate_stopnet = separate_stopnet
+        self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size,
@@ -290,18 +291,20 @@ class Decoder(nn.Module):
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_windowing,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+
+        self.attention = Attention(query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_windowing,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -347,18 +350,15 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
         # decoder states
-        self.attention_rnn_hidden = self.attention_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        # attention states
-        self.attention = inputs.data.new(B, T).zero_()
-        self.attention_cum = inputs.data.new(B, T).zero_()
+        self.context_vec = inputs.data.new(B, self.in_features).zero_()
         # cache attention inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)

     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
@@ -370,13 +370,15 @@ class Decoder(nn.Module):
     def decode(self, inputs, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
-        # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
-        # Concat RNN output and attention context vector
+
+        # Attention
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
+
+        # Concat query and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-                      -1))
+            torch.cat((self.query, self.context_vec), -1))
+
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@@ -384,18 +386,18 @@ class Decoder(nn.Module):
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
+        del decoder_input
         # predict mel vectors from decoder vectors
         output = self.proj_to_mel(decoder_output)
         output = torch.sigmoid(output)
         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
+        del decoder_output
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return output, stop_token, self.attention_layer.attention_weights
+        return output, stop_token, self.attention.attention_weights

     def _update_memory_queue(self, new_memory):
         if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
@@ -427,7 +429,7 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
@@ -453,8 +455,8 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         while True:
             if t > 0:
                 new_memory = outputs[-1]
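As the comment added to the constructor states, attention_rnn generates the queries for the attention mechanism. A rough standalone sketch of the renamed data flow in decode(), with stand-in tensors and illustrative sizes (the Attention call itself is elided; this is not the project's code):

import torch
import torch.nn as nn

B = 2
in_features, prenet_out_dim, query_dim = 256, 128, 256

attention_rnn = nn.GRUCell(in_features + prenet_out_dim, query_dim)

processed_memory = torch.randn(B, prenet_out_dim)  # prenet(self.memory_input)
context_vec = torch.zeros(B, in_features)          # self.context_vec
query = torch.zeros(B, query_dim)                  # self.query

# self.query = self.attention_rnn(cat(processed_memory, self.context_vec), self.query)
query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)
# self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask) would follow
print(query.shape)  # torch.Size([2, 256])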


@@ -104,7 +104,7 @@ class Decoder(nn.Module):
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
-        self.attention_rnn_dim = 1024
+        self.query_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
         self.max_decoder_steps = 1000
@@ -117,21 +117,21 @@ class Decoder(nn.Module):
                             [self.prenet_dim, self.prenet_dim], bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                         self.attention_rnn_dim)
+                                         self.query_dim)

-        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_win,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        self.attention = Attention(query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_win,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask)

-        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                    bias=True,
                    init_gain='sigmoid'))
-        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,10 +160,10 @@ class Decoder(nn.Module):
         # T = inputs.size(1)
         if not keep_states:
-            self.attention_hidden = self.attention_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.attention_cell = Variable(
-                inputs.data.new(B, self.attention_rnn_dim).zero_())
+            self.attention_rnn_cell_state = Variable(
+                inputs.data.new(B, self.query_dim).zero_())
             self.decoder_hidden = self.decoder_rnn_inits(
                 inputs.data.new_zeros(B).long())
@@ -174,7 +174,7 @@ class Decoder(nn.Module):
             inputs.data.new(B, self.encoder_embedding_dim).zero_())
         self.inputs = inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask

     def _reshape_memory(self, memories):
@@ -193,18 +193,18 @@ class Decoder(nn.Module):
         return outputs, stop_tokens, alignments

     def decode(self, memory):
-        cell_input = torch.cat((memory, self.context), -1)
-        self.attention_hidden, self.attention_cell = self.attention_rnn(
-            cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(
-            self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)
-        self.context = self.attention_layer(self.attention_hidden, self.inputs,
-                                            self.processed_inputs, self.mask)
-        memory = torch.cat((self.attention_hidden, self.context), -1)
+        query_input = torch.cat((memory, self.context), -1)
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
+        self.query = F.dropout(
+            self.query, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)
+        self.context = self.attention(self.query, self.inputs,
+                                      self.processed_inputs, self.mask)
+        memory = torch.cat((self.query, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             memory, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return decoder_output, stop_token, self.attention_layer.attention_weights
+        return decoder_output, stop_token, self.attention.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -251,8 +251,8 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs)
         self._init_states(inputs, mask=None)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
@@ -295,8 +295,8 @@ class Decoder(nn.Module):
         else:
             self._init_states(inputs, mask=None, keep_states=True)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0
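The Tacotron2 decoder applies the same renaming to an LSTM cell, so the attention RNN state splits into query and attention_rnn_cell_state, both passed through dropout before the query drives the attention module. A rough standalone sketch with illustrative sizes (stand-in tensors, not the project's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

B = 2
in_features, prenet_dim, query_dim = 512, 256, 1024
p_attention_dropout = 0.1

attention_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

memory = torch.randn(B, prenet_dim)                   # prenet output for this step
context = torch.zeros(B, in_features)                 # self.context
query = torch.zeros(B, query_dim)                     # self.query
attention_rnn_cell_state = torch.zeros(B, query_dim)  # self.attention_rnn_cell_state

query_input = torch.cat((memory, context), -1)
query, attention_rnn_cell_state = attention_rnn(
    query_input, (query, attention_rnn_cell_state))
query = F.dropout(query, p_attention_dropout, training=True)
attention_rnn_cell_state = F.dropout(
    attention_rnn_cell_state, p_attention_dropout, training=True)
# self.context = self.attention(self.query, self.inputs, ...) would follow here
print(query.shape)  # torch.Size([2, 1024])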