renamed attention_rnn to query_rnn

This commit is contained in:
Thomas Werkmeister 2019-07-23 18:38:09 +02:00
parent f08b31fd8d
commit 98edb7a4f8
3 changed files with 79 additions and 74 deletions

View File

@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
class Attention(nn.Module):
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
def __init__(self, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent, forward_attn_mask):
super(Attention, self).__init__()
self.query_layer = Linear(
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
query_dim, attention_dim, bias=False, init_gain='tanh')
self.inputs_layer = Linear(
embedding_dim, attention_dim, bias=False, init_gain='tanh')
self.v = Linear(attention_dim, 1, bias=True)
if trans_agent:
self.ta = nn.Linear(
attention_rnn_dim + embedding_dim, 1, bias=True)
query_dim + embedding_dim, 1, bias=True)
if location_attention:
self.location_layer = LocationLayer(
attention_dim,
@ -203,11 +203,12 @@ class Attention(nn.Module):
def apply_forward_attention(self, inputs, alignment, query):
# forward attention
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
(1, 0, 0, 0)).to(inputs.device)
prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
(1, 0, 0, 0))
# compute transition potentials
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
self.u * prev_alpha) + 1e-8) * alignment
alpha = ((1 - self.u) * self.alpha
+ self.u * prev_alpha
+ 1e-8) * alignment
# force incremental alignment
if not self.training and self.forward_attn_mask:
_, n = prev_alpha.max(1)
@ -231,19 +232,20 @@ class Attention(nn.Module):
self.u = torch.sigmoid(self.ta(ta_input))
return context, self.alpha
def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
def forward(self, query, inputs, processed_inputs, mask):
if self.location_attention:
attention, processed_query = self.get_location_attention(
attention_hidden_state, processed_inputs)
query, processed_inputs)
else:
attention, processed_query = self.get_attention(
attention_hidden_state, processed_inputs)
query, processed_inputs)
# apply masking
if mask is not None:
attention.data.masked_fill_(1 - mask, self._mask_value)
# apply windowing - only in eval mode
if not self.training and self.windowing:
attention = self.apply_windowing(attention, inputs)
# normalize attention values
if self.norm == "softmax":
alignment = torch.softmax(attention, dim=-1)
@ -258,7 +260,7 @@ class Attention(nn.Module):
# apply forward attention if enabled
if self.forward_attn:
context, self.attention_weights = self.apply_forward_attention(
inputs, alignment, attention_hidden_state)
inputs, alignment, query)
else:
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)

View File

@ -283,6 +283,7 @@ class Decoder(nn.Module):
self.memory_size = memory_size if memory_size > 0 else r
self.memory_dim = memory_dim
self.separate_stopnet = separate_stopnet
self.query_dim = 256
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(
memory_dim * self.memory_size,
@ -290,18 +291,18 @@ class Decoder(nn.Module):
prenet_dropout,
out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = nn.GRUCell(in_features + 128, 256)
self.attention_layer = Attention(attention_rnn_dim=256,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
self.attention = Attention(query_dim=self.query_dim,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
@ -310,7 +311,7 @@ class Decoder(nn.Module):
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
# learn init values instead of zero init.
self.attention_rnn_init = nn.Embedding(1, 256)
self.query_rnn_init = nn.Embedding(1, 256)
self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
self.decoder_rnn_inits = nn.Embedding(2, 256)
self.stopnet = StopNet(256 + memory_dim * r)
@ -347,18 +348,18 @@ class Decoder(nn.Module):
self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
# decoder states
self.attention_rnn_hidden = self.attention_rnn_init(
self.query = self.query_rnn_init(
inputs.data.new_zeros(B).long())
self.decoder_rnn_hiddens = [
self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
for idx in range(len(self.decoder_rnns))
]
self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
self.context_vec = inputs.data.new(B, self.in_features).zero_()
# attention states
self.attention = inputs.data.new(B, T).zero_()
self.attention_cum = inputs.data.new(B, T).zero_()
# cache attention inputs
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
self.processed_inputs = self.attention.inputs_layer(inputs)
def _parse_outputs(self, outputs, attentions, stop_tokens):
# Back to batch first
@ -370,13 +371,15 @@ class Decoder(nn.Module):
def decode(self, inputs, mask=None):
# Prenet
processed_memory = self.prenet(self.memory_input)
# Attention RNN
self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
# Concat RNN output and attention context vector
self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
# Concat query and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-1))
torch.cat((self.query, self.context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@ -384,18 +387,18 @@ class Decoder(nn.Module):
# Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input
del decoder_input
# predict mel vectors from decoder vectors
output = self.proj_to_mel(decoder_output)
output = torch.sigmoid(output)
# predict stop token
stopnet_input = torch.cat([decoder_output, output], -1)
del decoder_output
if self.separate_stopnet:
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
return output, stop_token, self.attention_layer.attention_weights
return output, stop_token, self.attention.attention_weights
def _update_memory_queue(self, new_memory):
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
@ -427,7 +430,7 @@ class Decoder(nn.Module):
stop_tokens = []
t = 0
self._init_states(inputs)
self.attention_layer.init_states(inputs)
self.attention.init_states(inputs)
while len(outputs) < memory.size(0):
if t > 0:
new_memory = memory[t - 1]
@ -453,8 +456,8 @@ class Decoder(nn.Module):
stop_tokens = []
t = 0
self._init_states(inputs)
self.attention_layer.init_win_idx()
self.attention_layer.init_states(inputs)
self.attention.init_win_idx()
self.attention.init_states(inputs)
while True:
if t > 0:
new_memory = outputs[-1]

View File

@ -104,7 +104,7 @@ class Decoder(nn.Module):
self.r = r
self.encoder_embedding_dim = in_features
self.separate_stopnet = separate_stopnet
self.attention_rnn_dim = 1024
self.query_dim = 1024
self.decoder_rnn_dim = 1024
self.prenet_dim = 256
self.max_decoder_steps = 1000
@ -116,22 +116,22 @@ class Decoder(nn.Module):
prenet_dropout,
[self.prenet_dim, self.prenet_dim], bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.attention_rnn_dim)
self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.query_dim)
self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
self.attention = Attention(query_dim=self.query_dim,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
self.decoder_rnn_dim, 1)
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
@ -145,7 +145,7 @@ class Decoder(nn.Module):
bias=True,
init_gain='sigmoid'))
self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
self.query_rnn_init = nn.Embedding(1, self.query_dim)
self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
self.memory_truncated = None
@ -160,10 +160,10 @@ class Decoder(nn.Module):
# T = inputs.size(1)
if not keep_states:
self.attention_hidden = self.attention_rnn_init(
self.query = self.query_rnn_init(
inputs.data.new_zeros(B).long())
self.attention_cell = Variable(
inputs.data.new(B, self.attention_rnn_dim).zero_())
self.query_rnn_cell_state = Variable(
inputs.data.new(B, self.query_dim).zero_())
self.decoder_hidden = self.decoder_rnn_inits(
inputs.data.new_zeros(B).long())
@ -174,7 +174,7 @@ class Decoder(nn.Module):
inputs.data.new(B, self.encoder_embedding_dim).zero_())
self.inputs = inputs
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
self.processed_inputs = self.attention.inputs_layer(inputs)
self.mask = mask
def _reshape_memory(self, memories):
@ -193,18 +193,18 @@ class Decoder(nn.Module):
return outputs, stop_tokens, alignments
def decode(self, memory):
cell_input = torch.cat((memory, self.context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training)
self.attention_cell = F.dropout(
self.attention_cell, self.p_attention_dropout, self.training)
query_input = torch.cat((memory, self.context), -1)
self.query, self.query_rnn_cell_state = self.query_rnn(
query_input, (self.query, self.query_rnn_cell_state))
self.query = F.dropout(
self.query, self.p_attention_dropout, self.training)
self.query_rnn_cell_state = F.dropout(
self.query_rnn_cell_state, self.p_attention_dropout, self.training)
self.context = self.attention_layer(self.attention_hidden, self.inputs,
self.processed_inputs, self.mask)
self.context = self.attention(self.query, self.inputs,
self.processed_inputs, self.mask)
memory = torch.cat((self.attention_hidden, self.context), -1)
memory = torch.cat((self.query, self.context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
memory, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden,
@ -223,7 +223,7 @@ class Decoder(nn.Module):
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
return decoder_output, stop_token, self.attention_layer.attention_weights
return decoder_output, stop_token, self.attention.attention_weights
def forward(self, inputs, memories, mask):
memory = self.get_go_frame(inputs).unsqueeze(0)
@ -232,7 +232,7 @@ class Decoder(nn.Module):
memories = self.prenet(memories)
self._init_states(inputs, mask=mask)
self.attention_layer.init_states(inputs)
self.attention.init_states(inputs)
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
@ -251,8 +251,8 @@ class Decoder(nn.Module):
memory = self.get_go_frame(inputs)
self._init_states(inputs, mask=None)
self.attention_layer.init_win_idx()
self.attention_layer.init_states(inputs)
self.attention.init_win_idx()
self.attention.init_states(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [True, False, False]
@ -295,8 +295,8 @@ class Decoder(nn.Module):
else:
self._init_states(inputs, mask=None, keep_states=True)
self.attention_layer.init_win_idx()
self.attention_layer.init_states(inputs)
self.attention.init_win_idx()
self.attention.init_states(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [True, False, False]
stop_count = 0