memory queueing

Eren Golge 2019-02-12 15:27:42 +01:00
parent ebc166b54f
commit 90f0cd640b
1 changed file with 13 additions and 5 deletions


@@ -298,16 +298,18 @@ class Decoder(nn.Module):
         in_features (int): input vector (encoder output) sample size.
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
+        memory_size (int): size of the past window. if <= 0 memory_size = r
     """
-    def __init__(self, in_features, memory_dim, r, attn_windowing):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
+        self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
+        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -325,7 +327,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
         self.attention_rnn_init = nn.Embedding(1, 256)
-        self.memory_init = nn.Embedding(1, r * memory_dim)
+        self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
         # self.init_layers()
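
For context on the sizing in the two hunks above: the Prenet input and the learned initial memory are now both tied to the same flattened window of memory_size * memory_dim features, instead of r * memory_dim. A minimal sketch of that relationship, with illustrative numbers that are not taken from the commit:

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real values come from the model config.
    memory_dim, r, memory_size = 80, 5, 7

    # Same construction as the patched memory_init: one learned init vector
    # covering the whole memory window.
    memory_init = nn.Embedding(1, memory_size * memory_dim)
    initial_memory = memory_init(torch.zeros(4, dtype=torch.long))
    print(initial_memory.shape)  # torch.Size([4, 560]) -> (B, memory_size * memory_dim)

The Prenet in the patch takes this same memory_dim * memory_size width as its input size, so the queued window can be fed to it directly.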
@@ -359,6 +361,7 @@ class Decoder(nn.Module):
         T = inputs.size(1)
         # go frame as zeros matrix
         initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
         # decoder states
         attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
         decoder_rnn_hiddens = [
@@ -404,9 +407,14 @@
         while True:
             if t > 0:
                 if memory is None:
-                    memory_input = outputs[-1]
+                    new_memory = outputs[-1]
                 else:
-                    memory_input = memory[t - 1]
+                    new_memory = memory[t - 1]
+                # Queuing if memory size defined else use previous prediction only.
+                if self.memory_size > 0:
+                    memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
+                else:
+                    memory_input = new_memory
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
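
The last hunk is the queueing itself: every decoder step drops the oldest r frames from the window and appends the newest prediction before running the Prenet. Below is a small self-contained sketch of that update outside the model; the batch size, loop length, and the random new_memory stand in for outputs[-1] / memory[t - 1] and are illustrative, not part of the commit:

    import torch

    # Illustrative shapes; in the model these come from the config.
    B, memory_dim, r, memory_size = 2, 80, 5, 7

    # The queue is kept as one flat (B, memory_size * memory_dim) tensor,
    # mirroring the flat layout produced by memory_init.
    memory_input = torch.zeros(B, memory_size * memory_dim)

    for t in range(3):
        # Stand-in for outputs[-1] or memory[t - 1]: the last r frames, flattened.
        new_memory = torch.randn(B, r * memory_dim)
        # Drop the oldest r * memory_dim features, append the newest r frames
        # (the same torch.cat update the commit adds to the decode loop).
        memory_input = torch.cat(
            [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)
        assert memory_input.shape == (B, memory_size * memory_dim)

Because __init__ forces memory_size to be at least r, the window never shrinks below one prediction group; with memory_size == r the update degenerates to feeding only the previous prediction, which matches the old behaviour.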