From 90f0cd640bee768e4d210c4ee977c897ccd6b7a0 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 12 Feb 2019 15:27:42 +0100
Subject: [PATCH] memory queueing

---
 layers/tacotron.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 3e4faca0..74738810 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -298,16 +298,18 @@ class Decoder(nn.Module):
         in_features (int): input vector (encoder output) sample size.
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
+        memory_size (int): number of past frames to queue as decoder input. If <= 0, memory_size defaults to r.
     """
 
-    def __init__(self, in_features, memory_dim, r, attn_windowing):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
+        self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
+        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -325,7 +327,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
         self.attention_rnn_init = nn.Embedding(1, 256)
-        self.memory_init = nn.Embedding(1, r * memory_dim)
+        self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
         # self.init_layers()
@@ -359,6 +361,7 @@ class Decoder(nn.Module):
         T = inputs.size(1)
         # go frame as zeros matrix
         initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
+
         # decoder states
         attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
         decoder_rnn_hiddens = [
@@ -404,9 +407,14 @@ class Decoder(nn.Module):
         while True:
             if t > 0:
                 if memory is None:
-                    memory_input = outputs[-1]
+                    new_memory = outputs[-1]
                 else:
-                    memory_input = memory[t - 1]
+                    new_memory = memory[t - 1]
+                # Queueing: shift in the new prediction if a memory size is set; else use the previous prediction only.
+                if self.memory_size > 0:
+                    memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
+                else:
+                    memory_input = new_memory
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
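
Below is a minimal sketch of the sliding-window queue this patch adds, isolated
from the decoder. The batch size and the values of memory_dim, r, and
memory_size are toy assumptions for illustration; only the slicing and
torch.cat call mirror the diff.

    import torch

    B = 2              # batch size (toy value)
    memory_dim = 80    # frame size (e.g. mel spectrogram bins)
    r = 2              # frames predicted per decoder step
    memory_size = 5    # past window length in frames (assumed >= r)

    # Flattened queue of the last `memory_size` frames: (B, memory_size * memory_dim).
    memory_input = torch.zeros(B, memory_size * memory_dim)

    for t in range(3):
        # Stand-in for the decoder output at step t: r frames, flattened.
        new_memory = torch.randn(B, r * memory_dim)
        # As in the patch: drop the oldest r frames, append the newest r frames.
        memory_input = torch.cat(
            [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)
        assert memory_input.shape == (B, memory_size * memory_dim)

Each step shifts the flattened window left by r frames and appends the new
prediction, so the prenet conditions on the last memory_size frames rather than
only the previous step's output; this is why the Prenet input and memory_init
embedding sizes change from r * memory_dim to memory_size * memory_dim above.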