Update decoder to accept location-sensitive attention

This commit is contained in:
Eren Golge 2018-05-17 08:03:40 -07:00
parent 3348d462d2
commit d43b4ccc1e
1 changed file with 16 additions and 8 deletions

View File

@ -223,7 +223,7 @@ class Decoder(nn.Module):
self.r = r self.r = r
# memory -> |Prenet| -> processed_memory # memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNN(256, in_features, 128) self.attention_rnn = AttentionRNN(256, in_features, 128)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256) self.project_to_decoder_in = nn.Linear(256+in_features, 256)
@ -252,6 +252,7 @@ class Decoder(nn.Module):
- memory: batch x #mels_pecs x mel_spec_dim - memory: batch x #mels_pecs x mel_spec_dim
""" """
B = inputs.size(0) B = inputs.size(0)
T = inputs.size(1)
# Run greedy decoding if memory is None # Run greedy decoding if memory is None
greedy = not self.training greedy = not self.training
if memory is not None: if memory is not None:
@ -269,11 +270,13 @@ class Decoder(nn.Module):
for _ in range(len(self.decoder_rnns))] for _ in range(len(self.decoder_rnns))]
current_context_vec = inputs.data.new(B, 256).zero_() current_context_vec = inputs.data.new(B, 256).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_() stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
attention_vec = memory.data.new(B, T).zero_()
attention_vec_cum = memory.data.new(B, T).zero_()
# Time first (T_decoder, B, memory_dim) # Time first (T_decoder, B, memory_dim)
if memory is not None: if memory is not None:
memory = memory.transpose(0, 1) memory = memory.transpose(0, 1)
outputs = [] outputs = []
alignments = [] attentions = []
stop_tokens = [] stop_tokens = []
t = 0 t = 0
memory_input = initial_memory memory_input = initial_memory
@ -286,8 +289,13 @@ class Decoder(nn.Module):
# Prenet # Prenet
processed_memory = self.prenet(memory_input) processed_memory = self.prenet(memory_input)
# Attention RNN # Attention RNN
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn( attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
processed_memory, current_context_vec, attention_rnn_hidden, inputs) attention_vec_cum.unsqueeze(1)),
dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
attention_vec_cum += attention_vec
attention_vec_cum /= (t + 1)
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_context_vec), -1)) torch.cat((attention_rnn_hidden, current_context_vec), -1))
@ -304,14 +312,14 @@ class Decoder(nn.Module):
# predict stop token # predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden) stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
outputs += [output] outputs += [output]
alignments += [alignment] attentions += [attention]
stop_tokens += [stop_token] stop_tokens += [stop_token]
t += 1 t += 1
if (not greedy and self.training) or (greedy and memory is not None): if (not greedy and self.training) or (greedy and memory is not None):
if t >= T_decoder: if t >= T_decoder:
break break
else: else:
if t > inputs.shape[1]/2 and stop_token > 0.8: if t > inputs.shape[1]/2 and stop_token > 0.6:
break break
elif t > self.max_decoder_steps: elif t > self.max_decoder_steps:
print(" !! Decoder stopped with 'max_decoder_steps'. \ print(" !! Decoder stopped with 'max_decoder_steps'. \
@ -319,10 +327,10 @@ class Decoder(nn.Module):
break break
assert greedy or len(outputs) == T_decoder assert greedy or len(outputs) == T_decoder
# Back to batch first # Back to batch first
alignments = torch.stack(alignments).transpose(0, 1) attentions = torch.stack(attentions).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous() outputs = torch.stack(outputs).transpose(0, 1).contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
return outputs, alignments, stop_tokens return outputs, attentions, stop_tokens
class StopNet(nn.Module): class StopNet(nn.Module):