Update decoder to accept location-sensitive attention

Eren Golge 2018-05-17 08:03:40 -07:00
parent 3348d462d2
commit d43b4ccc1e
1 changed file with 16 additions and 8 deletions


@@ -223,7 +223,7 @@ class Decoder(nn.Module):
         self.r = r
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
-        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
+        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNN(256, in_features, 128)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
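For readers without the `AttentionRNN` source at hand: the extra argument threaded through below (`attention_vec_cat`, the previous and cumulative alignments stacked on a channel dimension) is what makes the attention location-sensitive in the sense of Chorowski et al. (2015). The module below is a minimal, self-contained sketch of that scoring scheme, not the repo's `AttentionRNN`; the class name is hypothetical and the filter sizes are illustrative defaults borrowed from the Tacotron 2 paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocationSensitiveAttention(nn.Module):
    """Sketch of location-sensitive scoring: the score for each encoder
    step also sees features convolved from the previous and cumulative
    alignments (hypothetical module, not the repo's AttentionRNN)."""

    def __init__(self, query_dim, annot_dim, hidden_dim,
                 filters=32, kernel_size=31):
        super().__init__()
        padding = (kernel_size - 1) // 2
        # 2 input channels: previous alignment + cumulative alignment
        self.loc_conv = nn.Conv1d(2, filters, kernel_size,
                                  padding=padding, bias=False)
        self.loc_linear = nn.Linear(filters, hidden_dim, bias=False)
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=False)
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, annots, attention_cat):
        # query: (B, query_dim); annots: (B, T, annot_dim)
        # attention_cat: (B, 2, T) -- previous + cumulative alignments
        loc_feats = self.loc_conv(attention_cat).transpose(1, 2)  # (B, T, filters)
        scores = self.v(torch.tanh(
            self.query_layer(query).unsqueeze(1)
            + self.annot_layer(annots)
            + self.loc_linear(loc_feats))).squeeze(-1)            # (B, T)
        alignment = F.softmax(scores, dim=-1)
        # weighted sum of encoder annotations -> context vector
        context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)
        return context, alignment

if __name__ == "__main__":
    B, T = 2, 37
    attn = LocationSensitiveAttention(query_dim=256, annot_dim=256, hidden_dim=128)
    query, annots = torch.randn(B, 256), torch.randn(B, T, 256)
    attention_cat = torch.zeros(B, 2, T)   # step 0: both alignments start at zero
    context, alignment = attn(query, annots, attention_cat)
    print(context.shape, alignment.shape)  # torch.Size([2, 256]) torch.Size([2, 37])
```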
@@ -252,6 +252,7 @@ class Decoder(nn.Module):
             - memory: batch x #mel_specs x mel_spec_dim
         """
         B = inputs.size(0)
+        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
@@ -269,11 +270,13 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        attention_vec = inputs.data.new(B, T).zero_()
+        attention_vec_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
         outputs = []
-        alignments = []
+        attentions = []
         stop_tokens = []
         t = 0
         memory_input = initial_memory
@@ -286,8 +289,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
+            attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
+                                           attention_vec_cum.unsqueeze(1)),
+                                          dim=1)
+            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
+            attention_vec = attention
+            attention_vec_cum += attention_vec
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
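A quick shape check of the bookkeeping added above, as a standalone sketch with toy sizes: the two `(B, T)` alignment buffers are stacked into a `(B, 2, T)` tensor for the location convolution, and the cumulative buffer keeps a running sum of past alignments, in the spirit of Tacotron 2's cumulative attention weights.

```python
import torch

B, T = 2, 5  # toy batch size and encoder length
attention_vec = torch.zeros(B, T)
attention_vec_cum = torch.zeros(B, T)

for t in range(3):
    attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
                                   attention_vec_cum.unsqueeze(1)), dim=1)
    assert attention_vec_cat.shape == (B, 2, T)
    # stand-in for the attention RNN's softmax alignment at this step
    attention_vec = torch.softmax(torch.randn(B, T), dim=-1)
    attention_vec_cum += attention_vec

# each alignment sums to 1, so after 3 steps the cumulative mass is 3
print(attention_vec_cum.sum(dim=-1))  # tensor([3., 3.])
```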
@@ -304,14 +312,14 @@ class Decoder(nn.Module):
             # predict stop token
             stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
-            alignments += [alignment]
+            attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -319,10 +327,10 @@ class Decoder(nn.Module):
                     break
         assert greedy or len(outputs) == T_decoder
         # Back to batch first
-        alignments = torch.stack(alignments).transpose(0, 1)
+        attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, attentions, stop_tokens


 class StopNet(nn.Module):
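For context, a minimal smoke test of the updated forward pass. The constructor signature `Decoder(in_features, memory_dim, r)` and all sizes are assumptions inferred from the hunks above, not taken from the repo.

```python
import torch

# Hypothetical sizes; T_mel must be divisible by r so frames can be grouped.
B, T_in, in_features = 2, 37, 256
mel_dim, r, T_mel = 80, 5, 30

decoder = Decoder(in_features, mel_dim, r)   # assumed signature
inputs = torch.randn(B, T_in, in_features)   # encoder outputs
memory = torch.randn(B, T_mel, mel_dim)      # ground-truth frames (teacher forcing)

outputs, attentions, stop_tokens = decoder(inputs, memory)
print(outputs.shape)     # expected (B, T_mel // r, mel_dim * r)
print(attentions.shape)  # expected (B, T_mel // r, T_in)
```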